mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
rebase
This commit is contained in:
parent
a5d741ec3b
commit
e5e2ffe503
@ -1,3 +1,5 @@
|
||||
.. _custom_metal_kernels:
|
||||
|
||||
Custom Metal Kernels
|
||||
====================
|
||||
|
||||
@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
|
@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
|
||||
ys = mx.random.uniform(shape=(100, 4096))
|
||||
|
||||
def naive_add(xs, ys):
|
||||
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
|
||||
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
|
||||
|
||||
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||
|
||||
@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||
|
||||
# Vectorize over the second dimension of x and the
|
||||
# first dimension of y
|
||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
|
||||
|
||||
The ``in_axes`` parameter can be used to specify which dimensions of the
|
||||
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
||||
|
2
docs/build/html/_sources/usage/indexing.rst
vendored
2
docs/build/html/_sources/usage/indexing.rst
vendored
@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
|
||||
kernel would be extremely inefficient.
|
||||
|
||||
Indexing with boolean masks is something that MLX may support in the future. In
|
||||
general, MLX has limited support for operations for which outputs
|
||||
general, MLX has limited support for operations for which output
|
||||
*shapes* are dependent on input *data*. Other examples of these types of
|
||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||
single input version of :func:`numpy.where`.
|
||||
|
@ -109,7 +109,7 @@ Here is a concrete example:
|
||||
|
||||
An important behavior to be aware of is when the graph will be implicitly
|
||||
evaluated. Anytime you ``print`` an array, convert it to an
|
||||
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
|
||||
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
|
||||
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
|
||||
saving functions) will also evaluate the array.
|
||||
|
||||
|
@ -149,7 +149,7 @@ $(function(){ initResizable(false); });
|
||||
<div class="foldopen" id="foldopen00050" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1"> 50</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1">~ConcurrentContext</a>() {</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> enc.concurrent_ = <span class="keyword">false</span>;</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> enc.outputs_.insert(</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> enc.prev_outputs_.insert(</div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> enc.concurrent_outputs_.clear();</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> }</div>
|
||||
@ -170,212 +170,215 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522"> 66</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& a, <span class="keywordtype">int</span> idx, int64_t offset = 0);</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e"> 67</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">dispatchThreadgroups</a>(MTL::Size grid_dims, MTL::Size group_dims);</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810"> 68</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a>(MTL::Size grid_dims, MTL::Size group_dims);</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> </div>
|
||||
<div class="foldopen" id="foldopen00070" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034"> 70</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>() {</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> <span class="keywordflow">return</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a>(*<span class="keyword">this</span>);</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> }</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991"> 69</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a>();</div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> </div>
|
||||
<div class="foldopen" id="foldopen00071" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034"> 71</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>() {</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="keywordflow">return</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a>(*<span class="keyword">this</span>);</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> }</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198"> 73</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>();</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> </div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <span class="comment">// Inputs to all kernels in the encoder including temporaries</span></div>
|
||||
<div class="foldopen" id="foldopen00076" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509"> 76</a></span> std::unordered_set<const void*>& <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>() {</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="keywordflow">return</span> all_inputs_;</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> };</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198"> 74</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>();</div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> </div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> <span class="comment">// Inputs to all kernels in the encoder including temporaries</span></div>
|
||||
<div class="foldopen" id="foldopen00077" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509"> 77</a></span> std::unordered_set<const void*>& <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>() {</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="keywordflow">return</span> all_inputs_;</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> };</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> </div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <span class="comment">// Outputs of all kernels in the encoder including temporaries</span></div>
|
||||
<div class="foldopen" id="foldopen00081" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"> 81</a></span> std::unordered_set<const void*> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>() {</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordflow">return</span> all_outputs_;</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> };</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> </div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="comment">// Outputs of all kernels in the encoder including temporaries</span></div>
|
||||
<div class="foldopen" id="foldopen00082" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"> 82</a></span> std::unordered_set<const void*> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>() {</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="keywordflow">return</span> all_outputs_;</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> };</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> </div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <span class="keyword">private</span>:</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> MTL::ComputeCommandEncoder* enc_;</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keywordtype">bool</span> concurrent_{<span class="keyword">false</span>};</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> std::unordered_set<MTL::Resource*> outputs_;</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> std::unordered_set<MTL::Resource*> concurrent_outputs_;</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> std::unordered_set<const void*> all_inputs_;</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> std::unordered_set<const void*> all_outputs_;</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span>};</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> </div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> <span class="keyword">private</span>:</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> MTL::ComputeCommandEncoder* enc_;</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keywordtype">bool</span> needs_barrier_{<span class="keyword">false</span>};</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordtype">bool</span> concurrent_{<span class="keyword">false</span>};</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> std::unordered_set<MTL::Resource*> prev_outputs_;</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> std::unordered_set<MTL::Resource*> next_outputs_;</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> std::unordered_set<MTL::Resource*> concurrent_outputs_;</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> std::unordered_set<const void*> all_inputs_;</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> std::unordered_set<const void*> all_outputs_;</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span>};</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> </div>
|
||||
<div class="foldopen" id="foldopen00094" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html"> 94</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_fence.html">Fence</a> {</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79"> 95</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">Fence</a>(MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) {}</div>
|
||||
<div class="foldopen" id="foldopen00096" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2"> 96</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">~Fence</a>() {</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>->release();</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> }</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> </div>
|
||||
<div class="foldopen" id="foldopen00097" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html"> 97</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_fence.html">Fence</a> {</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79"> 98</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">Fence</a>(MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) {}</div>
|
||||
<div class="foldopen" id="foldopen00099" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2"> 99</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">~Fence</a>() {</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>->release();</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> }</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87"> 99</a></span> MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>;</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span>};</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87"> 102</a></span> MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>;</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span>};</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> </div>
|
||||
<div class="foldopen" id="foldopen00102" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html"> 102</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a> {</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7"> 103</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">DeviceStream</a>(MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) {};</div>
|
||||
<div class="foldopen" id="foldopen00104" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e"> 104</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">~DeviceStream</a>() {</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>->release();</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordflow">if</span> (<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a> != <span class="keyword">nullptr</span>) {</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>->release();</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> }</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> };</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> </div>
|
||||
<div class="foldopen" id="foldopen00105" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html"> 105</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a> {</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7"> 106</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">DeviceStream</a>(MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) {};</div>
|
||||
<div class="foldopen" id="foldopen00107" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e"> 107</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">~DeviceStream</a>() {</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>->release();</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordflow">if</span> (<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a> != <span class="keyword">nullptr</span>) {</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>->release();</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> }</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> };</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d"> 110</a></span> MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>;</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="comment">// A map of prior command encoder outputs to their corresponding fence</span></div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9"> 112</a></span> std::unordered_map<const void*, std::shared_ptr<Fence>> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">outputs</a>;</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="comment">// Used to allow thread-safe access to the outputs map</span></div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd"> 114</a></span> std::mutex <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">fence_mtx</a>;</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> </div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="comment">// The buffer and buffer op count are updated</span></div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="comment">// between command buffers</span></div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb"> 118</a></span> MTL::CommandBuffer* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782"> 119</a></span> <span class="keywordtype">int</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">buffer_ops</a>{0};</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> </div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="comment">// The command encoder, fence, and temporaries are updated between command</span></div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="comment">// encoders</span></div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd"> 123</a></span> std::unique_ptr<CommandEncoder> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">encoder</a>{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499"> 124</a></span> std::shared_ptr<Fence> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">fence</a>;</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de"> 125</a></span> std::vector<array> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">temporaries</a>;</div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span>};</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d"> 113</a></span> MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>;</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="comment">// A map of prior command encoder outputs to their corresponding fence</span></div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9"> 115</a></span> std::unordered_map<const void*, std::shared_ptr<Fence>> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">outputs</a>;</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="comment">// Used to allow thread-safe access to the outputs map</span></div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd"> 117</a></span> std::mutex <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">fence_mtx</a>;</div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> </div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="comment">// The buffer and buffer op count are updated</span></div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="comment">// between command buffers</span></div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb"> 121</a></span> MTL::CommandBuffer* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782"> 122</a></span> <span class="keywordtype">int</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">buffer_ops</a>{0};</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> </div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="comment">// The command encoder, fence, and temporaries are updated between command</span></div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <span class="comment">// encoders</span></div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd"> 126</a></span> std::unique_ptr<CommandEncoder> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">encoder</a>{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499"> 127</a></span> std::shared_ptr<Fence> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">fence</a>;</div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de"> 128</a></span> std::vector<array> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">temporaries</a>;</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span>};</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> </div>
|
||||
<div class="foldopen" id="foldopen00128" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html"> 128</a></span><span class="keyword">class </span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a> {</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keyword">public</span>:</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6"> 130</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6">Device</a>();</div>
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06"> 131</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">Device</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&) = <span class="keyword">delete</span>;</div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73"> 132</a></span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">operator=</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&) = <span class="keyword">delete</span>;</div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e"> 133</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">~Device</a>();</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> </div>
|
||||
<div class="foldopen" id="foldopen00135" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653"> 135</a></span> MTL::Device* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mtl_device</a>() {</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keywordflow">return</span> device_;</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> };</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> </div>
|
||||
<div class="foldopen" id="foldopen00131" data-start="{" data-end="};">
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html"> 131</a></span><span class="keyword">class </span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a> {</div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <span class="keyword">public</span>:</div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6"> 133</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6">Device</a>();</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06"> 134</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">Device</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&) = <span class="keyword">delete</span>;</div>
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73"> 135</a></span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">operator=</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&) = <span class="keyword">delete</span>;</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e"> 136</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">~Device</a>();</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> </div>
|
||||
<div class="foldopen" id="foldopen00138" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653"> 138</a></span> MTL::Device* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mtl_device</a>() {</div>
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordflow">return</span> device_;</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> };</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> </div>
|
||||
<div class="foldopen" id="foldopen00139" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b"> 139</a></span> <span class="keyword">const</span> std::string& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">get_architecture</a>() {</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordflow">return</span> arch_;</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> }</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> </div>
|
||||
<div class="foldopen" id="foldopen00142" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b"> 142</a></span> <span class="keyword">const</span> std::string& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">get_architecture</a>() {</div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordflow">return</span> arch_;</div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> }</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> </div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67"> 143</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">new_queue</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210"> 144</a></span> MTL::CommandBuffer* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">get_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8"> 145</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">get_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6"> 146</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">increment_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c"> 147</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">commit_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6"> 148</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6">get_command_encoder</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf"> 149</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">end_encoding</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> </div>
|
||||
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d"> 151</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(</div>
|
||||
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keyword">const</span> std::string& lib_name,</div>
|
||||
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keyword">const</span> std::string& lib_path);</div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> </div>
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="comment">// Note, this should remain in the header so that it is not dynamically</span></div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="comment">// linked</span></div>
|
||||
<div class="foldopen" id="foldopen00157" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf"> 157</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">register_library</a>(<span class="keyword">const</span> std::string& lib_name) {</div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keywordflow">if</span> (<span class="keyword">auto</span> it = library_map_.find(lib_name); it == library_map_.end()) {</div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(lib_name, <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a5fd6ba2040e53a254b9d71ae7ebd315f">get_colocated_mtllib_path</a>(lib_name));</div>
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> }</div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> }</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> </div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67"> 146</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">new_queue</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210"> 147</a></span> MTL::CommandBuffer* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">get_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8"> 148</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">get_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6"> 149</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">increment_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c"> 150</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">commit_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6"> 151</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6">get_command_encoder</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf"> 152</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">end_encoding</a>(<span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> </div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d"> 154</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(</div>
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keyword">const</span> std::string& lib_name,</div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keyword">const</span> std::string& lib_path);</div>
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> </div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="comment">// Note, this should remain in the header so that it is not dynamically</span></div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="comment">// linked</span></div>
|
||||
<div class="foldopen" id="foldopen00160" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf"> 160</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">register_library</a>(<span class="keyword">const</span> std::string& lib_name) {</div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keywordflow">if</span> (<span class="keyword">auto</span> it = library_map_.find(lib_name); it == library_map_.end()) {</div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(lib_name, <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a5fd6ba2040e53a254b9d71ae7ebd315f">get_colocated_mtllib_path</a>(lib_name));</div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> }</div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> }</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> </div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0"> 163</a></span> MTL::Library* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0">get_library</a>(</div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> std::function<std::string(<span class="keywordtype">void</span>)>& builder);</div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> </div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a"> 167</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">get_kernel</a>(</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> MTL::Library* mtl_lib,</div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keyword">const</span> std::string& hash_name = <span class="stringliteral">""</span>,</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> </div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf"> 174</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf">get_kernel</a>(</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> <span class="keyword">const</span> std::string& lib_name = <span class="stringliteral">"mlx"</span>,</div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keyword">const</span> std::string& hash_name = <span class="stringliteral">""</span>,</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> </div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881"> 181</a></span> MTL::ArgumentEncoder* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">argument_encoder</a>(</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keyword">const</span> std::vector<MTL::ArgumentDescriptor*>& arg_descs) <span class="keyword">const</span>;</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> </div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0"> 166</a></span> MTL::Library* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0">get_library</a>(</div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keyword">const</span> std::function<std::string(<span class="keywordtype">void</span>)>& builder);</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> </div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a"> 170</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">get_kernel</a>(</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> MTL::Library* mtl_lib,</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keyword">const</span> std::string& hash_name = <span class="stringliteral">""</span>,</div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> </div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf"> 177</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf">get_kernel</a>(</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> std::string& lib_name = <span class="stringliteral">"mlx"</span>,</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keyword">const</span> std::string& hash_name = <span class="stringliteral">""</span>,</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> </div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> <span class="comment">// Record temporary arrays for the given stream index</span></div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576"> 185</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">add_temporary</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a> arr, <span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901"> 186</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">add_temporaries</a>(std::vector<array> arrays, <span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> </div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f"> 188</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">set_residency_set</a>(<span class="keyword">const</span> MTL::ResidencySet* residency_set);</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> </div>
|
||||
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> <span class="keyword">private</span>:</div>
|
||||
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a>& get_stream_(<span class="keywordtype">int</span> index) {</div>
|
||||
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <span class="keywordflow">return</span> stream_map_.find(index)->second;</div>
|
||||
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> }</div>
|
||||
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> MTL::Library* get_library_cache_(<span class="keyword">const</span> std::string& name);</div>
|
||||
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> </div>
|
||||
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> MTL::Library* get_library_(<span class="keyword">const</span> std::string& name);</div>
|
||||
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> MTL::Library* build_library_(<span class="keyword">const</span> std::string& source_string);</div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881"> 184</a></span> MTL::ArgumentEncoder* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">argument_encoder</a>(</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> <span class="keyword">const</span> std::vector<MTL::ArgumentDescriptor*>& arg_descs) <span class="keyword">const</span>;</div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> </div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <span class="comment">// Record temporary arrays for the given stream index</span></div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576"> 188</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">add_temporary</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a> arr, <span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901"> 189</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">add_temporaries</a>(std::vector<array> arrays, <span class="keywordtype">int</span> index);</div>
|
||||
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
|
||||
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f"> 191</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">set_residency_set</a>(<span class="keyword">const</span> MTL::ResidencySet* residency_set);</div>
|
||||
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> </div>
|
||||
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> <span class="keyword">private</span>:</div>
|
||||
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a>& get_stream_(<span class="keywordtype">int</span> index) {</div>
|
||||
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keywordflow">return</span> stream_map_.find(index)->second;</div>
|
||||
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> }</div>
|
||||
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> MTL::Library* get_library_cache_(<span class="keyword">const</span> std::string& name);</div>
|
||||
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span> </div>
|
||||
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> MTL::Function* get_function_(<span class="keyword">const</span> std::string& name, MTL::Library* mtl_lib);</div>
|
||||
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> </div>
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> MTL::Function* get_function_(</div>
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">const</span> std::string& specialized_name,</div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> MTL::Library* mtl_lib);</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> </div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> MTL::LinkedFunctions* get_linked_functions_(</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> <span class="keyword">const</span> std::vector<MTL::Function*>& funcs);</div>
|
||||
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> MTL::Library* get_library_(<span class="keyword">const</span> std::string& name);</div>
|
||||
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> MTL::Library* build_library_(<span class="keyword">const</span> std::string& source_string);</div>
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> </div>
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> MTL::Function* get_function_(<span class="keyword">const</span> std::string& name, MTL::Library* mtl_lib);</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> </div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> MTL::Function* get_function_(</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keyword">const</span> std::string& specialized_name,</div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> MTL::Library* mtl_lib);</div>
|
||||
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> </div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keyword">const</span> MTL::Function* mtl_function);</div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> </div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> <span class="keyword">const</span> MTL::Function* mtl_function,</div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> <span class="keyword">const</span> MTL::LinkedFunctions* linked_functions);</div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> </div>
|
||||
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> MTL::Library* mtl_lib,</div>
|
||||
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> </div>
|
||||
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> MTL::Device* device_;</div>
|
||||
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> std::unordered_map<int32_t, DeviceStream> stream_map_;</div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> MTL::LinkedFunctions* get_linked_functions_(</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> std::vector<MTL::Function*>& funcs);</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> </div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="keyword">const</span> MTL::Function* mtl_function);</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> </div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> <span class="keyword">const</span> std::string& name,</div>
|
||||
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="keyword">const</span> MTL::Function* mtl_function,</div>
|
||||
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">const</span> MTL::LinkedFunctions* linked_functions);</div>
|
||||
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> </div>
|
||||
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> MTL::ComputePipelineState* get_kernel_(</div>
|
||||
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">const</span> std::string& base_name,</div>
|
||||
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> MTL::Library* mtl_lib,</div>
|
||||
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>& func_consts = {},</div>
|
||||
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> <span class="keyword">const</span> std::vector<MTL::Function*>& linked_functions = {});</div>
|
||||
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> </div>
|
||||
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> std::shared_mutex kernel_mtx_;</div>
|
||||
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;</div>
|
||||
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> MTL::Device* device_;</div>
|
||||
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> std::unordered_map<int32_t, DeviceStream> stream_map_;</div>
|
||||
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> </div>
|
||||
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> std::shared_mutex library_mtx_;</div>
|
||||
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> std::unordered_map<std::string, MTL::Library*> library_map_;</div>
|
||||
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> <span class="keyword">const</span> MTL::ResidencySet* residency_set_{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> std::string arch_;</div>
|
||||
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span>};</div>
|
||||
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> std::shared_mutex kernel_mtx_;</div>
|
||||
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;</div>
|
||||
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> </div>
|
||||
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> std::shared_mutex library_mtx_;</div>
|
||||
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> std::unordered_map<std::string, MTL::Library*> library_map_;</div>
|
||||
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> <span class="keyword">const</span> MTL::ResidencySet* residency_set_{<span class="keyword">nullptr</span>};</div>
|
||||
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> std::string arch_;</div>
|
||||
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span>};</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> </div>
|
||||
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57"> 238</a></span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>& <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57">device</a>(<a class="code hl_struct" href="structmlx_1_1core_1_1_device.html">mlx::core::Device</a>);</div>
|
||||
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> </div>
|
||||
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span>} <span class="comment">// namespace mlx::core::metal</span></div>
|
||||
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> </div>
|
||||
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57"> 241</a></span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>& <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57">device</a>(<a class="code hl_struct" href="structmlx_1_1core_1_1_device.html">mlx::core::Device</a>);</div>
|
||||
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> </div>
|
||||
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span>} <span class="comment">// namespace mlx::core::metal</span></div>
|
||||
<div class="ttc" id="aarray_8h_html"><div class="ttname"><a href="array_8h.html">array.h</a></div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a03a2f0c712660a1bd437cb16e4aba79f"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">mlx::core::metal::Device::set_residency_set</a></div><div class="ttdeci">void set_residency_set(const MTL::ResidencySet *residency_set)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a064e1cb6a16de7a0619f6447622350f8"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">mlx::core::metal::Device::get_command_buffer_ops</a></div><div class="ttdeci">int get_command_buffer_ops(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a31dba377f2be44a746db10d1b9367653"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mlx::core::metal::Device::mtl_device</a></div><div class="ttdeci">MTL::Device * mtl_device()</div><div class="ttdef"><b>Definition</b> device.h:135</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a31dba377f2be44a746db10d1b9367653"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mlx::core::metal::Device::mtl_device</a></div><div class="ttdeci">MTL::Device * mtl_device()</div><div class="ttdef"><b>Definition</b> device.h:138</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a45945f2efcd242d915ffa2171e92bf9d"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &lib_name, const std::string &lib_path)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a4f39c28c6cdd1d2da1918f5871bcba6e"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">mlx::core::metal::Device::~Device</a></div><div class="ttdeci">~Device()</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a5fe3970fbe92ccc55fce4241ffbe5210"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">mlx::core::metal::Device::get_command_buffer</a></div><div class="ttdeci">MTL::CommandBuffer * get_command_buffer(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a60689f97347811b27e8c5ca23e0372bf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">mlx::core::metal::Device::end_encoding</a></div><div class="ttdeci">void end_encoding(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a65f64dd8bafdc704d871fc5be5e7bc0b"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">mlx::core::metal::Device::get_architecture</a></div><div class="ttdeci">const std::string & get_architecture()</div><div class="ttdef"><b>Definition</b> device.h:139</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a65f64dd8bafdc704d871fc5be5e7bc0b"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">mlx::core::metal::Device::get_architecture</a></div><div class="ttdeci">const std::string & get_architecture()</div><div class="ttdef"><b>Definition</b> device.h:142</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a6810c4dcbcfbf93fc51d42aa5ff0fc3a"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">mlx::core::metal::Device::get_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a6e33e2b1287324fb4a6575e0da5e5881"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">mlx::core::metal::Device::argument_encoder</a></div><div class="ttdeci">MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a72ad17c96fc6ce825bc77f0bed657901"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">mlx::core::metal::Device::add_temporaries</a></div><div class="ttdeci">void add_temporaries(std::vector< array > arrays, int index)</div></div>
|
||||
@ -383,7 +386,7 @@ $(function(){ initResizable(false); });
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a7a33d4d601423a3d3c23d5ad7072abb6"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">mlx::core::metal::Device::increment_command_buffer_ops</a></div><div class="ttdeci">void increment_command_buffer_ops(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a8135ae2a8c1e6f3861e84d4e60c28b67"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">mlx::core::metal::Device::new_queue</a></div><div class="ttdeci">void new_queue(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a95248f1387824067fd4fed23ace5ac0c"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">mlx::core::metal::Device::commit_command_buffer</a></div><div class="ttdeci">void commit_command_buffer(int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a99ff72689b7beb65ad4541391b0eeabf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &lib_name)</div><div class="ttdef"><b>Definition</b> device.h:157</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a99ff72689b7beb65ad4541391b0eeabf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &lib_name)</div><div class="ttdef"><b>Definition</b> device.h:160</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_abf59a4addb5473f9e814e3651ba85f06"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">mlx::core::metal::Device::Device</a></div><div class="ttdeci">Device(const Device &)=delete</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_acb90010af0cffe27fd8cc6c253d3a576"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">mlx::core::metal::Device::add_temporary</a></div><div class="ttdeci">void add_temporary(array arr, int index)</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_ad1d6382fd18a46b1906e1b43e0bd2e73"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">mlx::core::metal::Device::operator=</a></div><div class="ttdeci">Device & operator=(const Device &)=delete</div></div>
|
||||
@ -402,31 +405,32 @@ $(function(){ initResizable(false); });
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></div><div class="ttdef"><b>Definition</b> device.h:41</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a1e41477f2f489e38499f7830a91c9810"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">mlx::core::metal::CommandEncoder::dispatchThreads</a></div><div class="ttdeci">void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a2334774486f447213ee997e55c2e52a3"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3">mlx::core::metal::CommandEncoder::CommandEncoder</a></div><div class="ttdeci">CommandEncoder(MTL::CommandBuffer *cbuf)</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a27ded7e54bc1712063c874646b445509"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">mlx::core::metal::CommandEncoder::inputs</a></div><div class="ttdeci">std::unordered_set< const void * > & inputs()</div><div class="ttdef"><b>Definition</b> device.h:76</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a27ded7e54bc1712063c874646b445509"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">mlx::core::metal::CommandEncoder::inputs</a></div><div class="ttdeci">std::unordered_set< const void * > & inputs()</div><div class="ttdef"><b>Definition</b> device.h:77</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a3f42a1362b4a513fa89e7b3dcc570a8e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">mlx::core::metal::CommandEncoder::operator=</a></div><div class="ttdeci">CommandEncoder & operator=(const CommandEncoder &)=delete</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a48b548a0b15f9d1279c938a1c6167034"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">mlx::core::metal::CommandEncoder::start_concurrent</a></div><div class="ttdeci">ConcurrentContext start_concurrent()</div><div class="ttdef"><b>Definition</b> device.h:70</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a48b548a0b15f9d1279c938a1c6167034"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">mlx::core::metal::CommandEncoder::start_concurrent</a></div><div class="ttdeci">ConcurrentContext start_concurrent()</div><div class="ttdef"><b>Definition</b> device.h:71</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a6a2e28e542eaa2886041bddd51ff6522"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">mlx::core::metal::CommandEncoder::set_output_array</a></div><div class="ttdeci">void set_output_array(array &a, int idx, int64_t offset=0)</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a74bcd8e35f80f5a62db48c4a2bb0173e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">mlx::core::metal::CommandEncoder::dispatchThreadgroups</a></div><div class="ttdeci">void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a9b6dd221ccd2d939d544004cb6279198"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">mlx::core::metal::CommandEncoder::~CommandEncoder</a></div><div class="ttdeci">~CommandEncoder()</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aac45ab0630ea32cf7d15c7ba3e229966"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">mlx::core::metal::CommandEncoder::operator-></a></div><div class="ttdeci">MTL::ComputeCommandEncoder * operator->()</div><div class="ttdef"><b>Definition</b> device.h:61</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ab69ff0d7f14b9b59db4df0608193dce4"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">mlx::core::metal::CommandEncoder::set_input_array</a></div><div class="ttdeci">void set_input_array(const array &a, int idx, int64_t offset=0)</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ac68ca977b5bde5434284ce7979647f14"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14">mlx::core::metal::CommandEncoder::CommandEncoder</a></div><div class="ttdeci">CommandEncoder(const CommandEncoder &)=delete</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aefa48740fdee884f02e2d379bca4e78f"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">mlx::core::metal::CommandEncoder::outputs</a></div><div class="ttdeci">std::unordered_set< const void * > outputs()</div><div class="ttdef"><b>Definition</b> device.h:81</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html">mlx::core::metal::DeviceStream</a></div><div class="ttdef"><b>Definition</b> device.h:102</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a1c4397732f64f5811381dd01e30e020e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">mlx::core::metal::DeviceStream::~DeviceStream</a></div><div class="ttdeci">~DeviceStream()</div><div class="ttdef"><b>Definition</b> device.h:104</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a55a7a92c6abad369c99a5ede7a2521b9"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">mlx::core::metal::DeviceStream::outputs</a></div><div class="ttdeci">std::unordered_map< const void *, std::shared_ptr< Fence > > outputs</div><div class="ttdef"><b>Definition</b> device.h:112</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a573326bc8b48e39076850c7bf52ad0d7"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">mlx::core::metal::DeviceStream::DeviceStream</a></div><div class="ttdeci">DeviceStream(MTL::CommandQueue *queue)</div><div class="ttdef"><b>Definition</b> device.h:103</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a58e435217b9922f882507ebf48bfbbdd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">mlx::core::metal::DeviceStream::encoder</a></div><div class="ttdeci">std::unique_ptr< CommandEncoder > encoder</div><div class="ttdef"><b>Definition</b> device.h:123</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a6fa08cca881fc3798ae45994a11a4fcd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">mlx::core::metal::DeviceStream::fence_mtx</a></div><div class="ttdeci">std::mutex fence_mtx</div><div class="ttdef"><b>Definition</b> device.h:114</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a77c75a63c51ea56815a86bd882ed190d"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">mlx::core::metal::DeviceStream::queue</a></div><div class="ttdeci">MTL::CommandQueue * queue</div><div class="ttdef"><b>Definition</b> device.h:110</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a876199de8da1efa9a362451029638499"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">mlx::core::metal::DeviceStream::fence</a></div><div class="ttdeci">std::shared_ptr< Fence > fence</div><div class="ttdef"><b>Definition</b> device.h:124</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a99183c92599edfeb75f7fa0f37e1d9eb"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">mlx::core::metal::DeviceStream::buffer</a></div><div class="ttdeci">MTL::CommandBuffer * buffer</div><div class="ttdef"><b>Definition</b> device.h:118</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_ab6048b329e65a59033834f3bdd351782"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">mlx::core::metal::DeviceStream::buffer_ops</a></div><div class="ttdeci">int buffer_ops</div><div class="ttdef"><b>Definition</b> device.h:119</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_aee88009117dfff1ad121eabe28d5f3de"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">mlx::core::metal::DeviceStream::temporaries</a></div><div class="ttdeci">std::vector< array > temporaries</div><div class="ttdef"><b>Definition</b> device.h:125</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html">mlx::core::metal::Fence</a></div><div class="ttdef"><b>Definition</b> device.h:94</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a30bee4957ae595e04922952a8010fc79"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">mlx::core::metal::Fence::Fence</a></div><div class="ttdeci">Fence(MTL::Fence *fence)</div><div class="ttdef"><b>Definition</b> device.h:95</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a4940c1aece13814af7727de9abb511f2"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">mlx::core::metal::Fence::~Fence</a></div><div class="ttdeci">~Fence()</div><div class="ttdef"><b>Definition</b> device.h:96</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_aeccd8f2b81418ae9fc446ae2b6e15b87"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">mlx::core::metal::Fence::fence</a></div><div class="ttdeci">MTL::Fence * fence</div><div class="ttdef"><b>Definition</b> device.h:99</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ad538ae88f90560063f9ba502e2795991"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder::maybeInsertBarrier</a></div><div class="ttdeci">void maybeInsertBarrier()</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aefa48740fdee884f02e2d379bca4e78f"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">mlx::core::metal::CommandEncoder::outputs</a></div><div class="ttdeci">std::unordered_set< const void * > outputs()</div><div class="ttdef"><b>Definition</b> device.h:82</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html">mlx::core::metal::DeviceStream</a></div><div class="ttdef"><b>Definition</b> device.h:105</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a1c4397732f64f5811381dd01e30e020e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">mlx::core::metal::DeviceStream::~DeviceStream</a></div><div class="ttdeci">~DeviceStream()</div><div class="ttdef"><b>Definition</b> device.h:107</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a55a7a92c6abad369c99a5ede7a2521b9"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">mlx::core::metal::DeviceStream::outputs</a></div><div class="ttdeci">std::unordered_map< const void *, std::shared_ptr< Fence > > outputs</div><div class="ttdef"><b>Definition</b> device.h:115</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a573326bc8b48e39076850c7bf52ad0d7"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">mlx::core::metal::DeviceStream::DeviceStream</a></div><div class="ttdeci">DeviceStream(MTL::CommandQueue *queue)</div><div class="ttdef"><b>Definition</b> device.h:106</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a58e435217b9922f882507ebf48bfbbdd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">mlx::core::metal::DeviceStream::encoder</a></div><div class="ttdeci">std::unique_ptr< CommandEncoder > encoder</div><div class="ttdef"><b>Definition</b> device.h:126</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a6fa08cca881fc3798ae45994a11a4fcd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">mlx::core::metal::DeviceStream::fence_mtx</a></div><div class="ttdeci">std::mutex fence_mtx</div><div class="ttdef"><b>Definition</b> device.h:117</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a77c75a63c51ea56815a86bd882ed190d"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">mlx::core::metal::DeviceStream::queue</a></div><div class="ttdeci">MTL::CommandQueue * queue</div><div class="ttdef"><b>Definition</b> device.h:113</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a876199de8da1efa9a362451029638499"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">mlx::core::metal::DeviceStream::fence</a></div><div class="ttdeci">std::shared_ptr< Fence > fence</div><div class="ttdef"><b>Definition</b> device.h:127</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a99183c92599edfeb75f7fa0f37e1d9eb"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">mlx::core::metal::DeviceStream::buffer</a></div><div class="ttdeci">MTL::CommandBuffer * buffer</div><div class="ttdef"><b>Definition</b> device.h:121</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_ab6048b329e65a59033834f3bdd351782"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">mlx::core::metal::DeviceStream::buffer_ops</a></div><div class="ttdeci">int buffer_ops</div><div class="ttdef"><b>Definition</b> device.h:122</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_aee88009117dfff1ad121eabe28d5f3de"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">mlx::core::metal::DeviceStream::temporaries</a></div><div class="ttdeci">std::vector< array > temporaries</div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html">mlx::core::metal::Fence</a></div><div class="ttdef"><b>Definition</b> device.h:97</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a30bee4957ae595e04922952a8010fc79"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">mlx::core::metal::Fence::Fence</a></div><div class="ttdeci">Fence(MTL::Fence *fence)</div><div class="ttdef"><b>Definition</b> device.h:98</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a4940c1aece13814af7727de9abb511f2"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">mlx::core::metal::Fence::~Fence</a></div><div class="ttdeci">~Fence()</div><div class="ttdef"><b>Definition</b> device.h:99</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_aeccd8f2b81418ae9fc446ae2b6e15b87"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">mlx::core::metal::Fence::fence</a></div><div class="ttdeci">MTL::Fence * fence</div><div class="ttdef"><b>Definition</b> device.h:102</div></div>
|
||||
</div><!-- fragment --></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
|
@ -864,7 +864,7 @@
|
||||
<article class="bd-article">
|
||||
|
||||
<section id="custom-metal-kernels">
|
||||
<h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
|
||||
<span id="id1"></span><h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
|
||||
<p>MLX supports writing custom Metal kernels through the Python and C++ APIs.</p>
|
||||
<section id="simple-example">
|
||||
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
|
||||
@ -947,6 +947,9 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
|
||||
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">"custom_kernel_myexp_float"</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">;</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
|
||||
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.</p>
|
||||
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
|
||||
</section>
|
||||
<section id="using-shape-strides">
|
||||
|
18
docs/build/html/doxygen_crawl.html
vendored
18
docs/build/html/doxygen_crawl.html
vendored
@ -4330,9 +4330,9 @@
|
||||
<a href="kernels_8h.html#a195b86cad5bb99aa1bcd23952305af6b"/>
|
||||
<a href="kernels_8h.html#a1d4cffc3c78067b3d9a62d64f3fb686f"/>
|
||||
<a href="kernels_8h.html#a35a412f688d79eb47e42d20a7c8650ee"/>
|
||||
<a href="kernels_8h.html#a3bd386cb6db09f636963ce66ceaf8647"/>
|
||||
<a href="kernels_8h.html#a4decd4a07d91487e6903f6e3c8b7513a"/>
|
||||
<a href="kernels_8h.html#a4e809746f48e5dcf7fa63215d3f5e33e"/>
|
||||
<a href="kernels_8h.html#a51c4bb09230348bd0252e22bfdc9bc89"/>
|
||||
<a href="kernels_8h.html#a54eb3b65375022428aab5f810e40624b"/>
|
||||
<a href="kernels_8h.html#a76f614e9956a6ca05a9be4db5a483446"/>
|
||||
<a href="kernels_8h.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"/>
|
||||
@ -4443,9 +4443,9 @@
|
||||
<a href="metal_2kernels_2unary_8h.html#a7c7690f0df9d2acc60b63be58d9c7777"/>
|
||||
<a href="metal_2kernels_2unary_8h.html#ac965f8d3ed62f8580dbfb645e83d4ae5"/>
|
||||
<a href="metal_2reduce_8h.html"/>
|
||||
<a href="metal_2reduce_8h.html#a3ab0fd997d9a35782106ff083a72e098"/>
|
||||
<a href="metal_2reduce_8h.html#aa0332c64ee9965f05026c30a0b778000"/>
|
||||
<a href="metal_2reduce_8h.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"/>
|
||||
<a href="metal_2reduce_8h.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"/>
|
||||
<a href="metal_2slicing_8h.html"/>
|
||||
<a href="metal_2slicing_8h.html#a050299d0d366ca5c9d09d1004dcc3e7d"/>
|
||||
<a href="metal_2slicing_8h.html#a59048c5ff114c101a496bf33f62e3de9"/>
|
||||
@ -4850,9 +4850,11 @@
|
||||
<a href="namespacemlx_1_1core.html#a3a6f43c2485f0d42293184f1aecbeaee"/>
|
||||
<a href="namespacemlx_1_1core.html#a3a8f6f0af477788c4f0aa98abfc5f1ab"/>
|
||||
<a href="namespacemlx_1_1core.html#a3a8fe7ba84714dbb5fdc81e93a07abc8"/>
|
||||
<a href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098"/>
|
||||
<a href="namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782"/>
|
||||
<a href="namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027"/>
|
||||
<a href="namespacemlx_1_1core.html#a3ba20a804c306067b7023259429e0e48"/>
|
||||
<a href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647"/>
|
||||
<a href="namespacemlx_1_1core.html#a3c41a304126bc225bdc68062d1eb6e7e"/>
|
||||
<a href="namespacemlx_1_1core.html#a3cc5c154e4ad9a83ad43da8513146fdc"/>
|
||||
<a href="namespacemlx_1_1core.html#a3d2b2929ed4636e9e2b86e125b2e57d9"/>
|
||||
@ -4900,7 +4902,6 @@
|
||||
<a href="namespacemlx_1_1core.html#a514263e63f6825b490203ca586864687"/>
|
||||
<a href="namespacemlx_1_1core.html#a514cf8b4e6f0a6af3a867e752f4338f7"/>
|
||||
<a href="namespacemlx_1_1core.html#a517019d42d4e426b7b98e1c719bb47ce"/>
|
||||
<a href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89"/>
|
||||
<a href="namespacemlx_1_1core.html#a5287610200ff573730c9c92413f48881"/>
|
||||
<a href="namespacemlx_1_1core.html#a54833be1d44bc3adfc9ea218fc3685bd"/>
|
||||
<a href="namespacemlx_1_1core.html#a54863a54f258acf2b5c734950618e4e1"/>
|
||||
@ -5272,7 +5273,6 @@
|
||||
<a href="namespacemlx_1_1core.html#af69db7def588d7da430434a69456e29c"/>
|
||||
<a href="namespacemlx_1_1core.html#af7577c91b8c43682f0ebc9eb9758aae4"/>
|
||||
<a href="namespacemlx_1_1core.html#af776fd91dd60594dcfebbafd17f19068"/>
|
||||
<a href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"/>
|
||||
<a href="namespacemlx_1_1core.html#af7eea1682a38d363c56a066321e6d526"/>
|
||||
<a href="namespacemlx_1_1core.html#af810587a17e692f4eec256d3c3cd27de"/>
|
||||
<a href="namespacemlx_1_1core.html#af84ed854132c1514dca5a524fdb7ed05"/>
|
||||
@ -5933,11 +5933,11 @@
|
||||
<a href="quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd"/>
|
||||
<a href="quantized_8h.html#a07b26d2d0b0d65dfe925c452c453fa42"/>
|
||||
<a href="quantized_8h.html#a0ba59096494f1001c195312571523ae9"/>
|
||||
<a href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a"/>
|
||||
<a href="quantized_8h.html#a1a66b061c46383952a0f067c3848971f"/>
|
||||
<a href="quantized_8h.html#a2ce135e392dbf9a3e5180fb083792ed7"/>
|
||||
<a href="quantized_8h.html#a3ab400746ad77be89c30d25638e01698"/>
|
||||
<a href="quantized_8h.html#a47bcf4a14566e01e14bd3c155811db59"/>
|
||||
<a href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12"/>
|
||||
<a href="quantized_8h.html#a530b720e123e59d73ea89a0a2d0946b7"/>
|
||||
<a href="quantized_8h.html#a6076203615038eb06816158f7b3869c6"/>
|
||||
<a href="quantized_8h.html#a62969a218d93680f5e35d0c61b160b99"/>
|
||||
@ -5952,6 +5952,7 @@
|
||||
<a href="quantized_8h.html#aa69e143d646fad332c1a53e8c9b337b7"/>
|
||||
<a href="quantized_8h.html#ab1ae143eba2afceb8df63f38b26f9a84"/>
|
||||
<a href="quantized_8h.html#ab364d58ab652e3ad87a8f80910556071"/>
|
||||
<a href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8"/>
|
||||
<a href="quantized_8h.html#aba7687e6f8f1d29c0a1b2a3db150bd81"/>
|
||||
<a href="quantized_8h.html#abe2e3ef0ee4ec2cb61dc5330ad463d10"/>
|
||||
<a href="quantized_8h.html#accab1f9e17a65242347c051f98e4c0be"/>
|
||||
@ -6016,8 +6017,10 @@
|
||||
<a href="reduce__all_8h.html"/>
|
||||
<a href="reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d"/>
|
||||
<a href="reduce__col_8h.html"/>
|
||||
<a href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"/>
|
||||
<a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"/>
|
||||
<a href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce"/>
|
||||
<a href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb"/>
|
||||
<a href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5"/>
|
||||
<a href="reduce__init_8h.html"/>
|
||||
<a href="reduce__init_8h.html#a0088604ac2eaa6940689ff12c4ba5fc2"/>
|
||||
<a href="reduce__row_8h.html"/>
|
||||
@ -6051,7 +6054,7 @@
|
||||
<a href="scheduler_8h.html#aa2d4eacf5d5cbc778a51aafd4fd8e4d7"/>
|
||||
<a href="scheduler_8h.html#ae856e468c2f7c8f8ec672522cc13730b"/>
|
||||
<a href="sdpa__vector_8h.html"/>
|
||||
<a href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927"/>
|
||||
<a href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae"/>
|
||||
<a href="sort_8h.html"/>
|
||||
<a href="sort_8h.html#a0386011c52d03e60885a31e6fbd903dd"/>
|
||||
<a href="sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2"/>
|
||||
@ -6929,6 +6932,7 @@
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html"/>
|
||||
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1"/>
|
||||
|
1
docs/build/html/functions_func_m.html
vendored
1
docs/build/html/functions_func_m.html
vendored
@ -93,6 +93,7 @@ $(function(){ initResizable(false); });
|
||||
<li>Matmul() : <a class="el" href="classmlx_1_1core_1_1_matmul.html#adef92f30ab35e540ccb316ea6b94e6f7">mlx::core::Matmul</a></li>
|
||||
<li>max() : <a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a92320d40a58218e40cc414986ac95c50">metal::_numeric_limits_impl< bfloat16_t ></a></li>
|
||||
<li>Maximum() : <a class="el" href="classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816">mlx::core::Maximum</a></li>
|
||||
<li>maybeInsertBarrier() : <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder</a></li>
|
||||
<li>merge_partition() : <a class="el" href="struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca">BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a>, <a class="el" href="struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe">KernelMultiBlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a></li>
|
||||
<li>merge_step() : <a class="el" href="struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c">BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a></li>
|
||||
<li>min() : <a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f">metal::_numeric_limits_impl< bfloat16_t ></a></li>
|
||||
|
1
docs/build/html/functions_m.html
vendored
1
docs/build/html/functions_m.html
vendored
@ -102,6 +102,7 @@ $(function(){ initResizable(false); });
|
||||
<li>max_exponent : <a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a61bb136f819fa392c50bdf3c38f3aad2">metal::_numeric_limits_impl< bfloat16_t ></a></li>
|
||||
<li>max_exponent10 : <a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a76bfb2deb0e0afc011f77bf5a6d0ed94">metal::_numeric_limits_impl< bfloat16_t ></a></li>
|
||||
<li>Maximum() : <a class="el" href="classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816">mlx::core::Maximum</a></li>
|
||||
<li>maybeInsertBarrier() : <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder</a></li>
|
||||
<li>merge_partition() : <a class="el" href="struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca">BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a>, <a class="el" href="struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe">KernelMultiBlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a></li>
|
||||
<li>merge_step() : <a class="el" href="struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c">BlockMergeSort< val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp ></a></li>
|
||||
<li>Min : <a class="el" href="classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef">mlx::core::distributed::AllReduce</a>, <a class="el" href="classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418">mlx::core::Reduce</a>, <a class="el" href="classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898">mlx::core::Scan</a>, <a class="el" href="classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf">mlx::core::Scatter</a></li>
|
||||
|
4
docs/build/html/globals_c.html
vendored
4
docs/build/html/globals_c.html
vendored
@ -92,8 +92,10 @@ $(function(){ initResizable(false); });
|
||||
<li>can_convert_to_bfloat : <a class="el" href="backend_2metal_2kernels_2bf16_8h.html#aae77817d261452b2f001f4d947a3e04e">bf16.h</a></li>
|
||||
<li>can_convert_to_complex64 : <a class="el" href="backend_2metal_2kernels_2complex_8h.html#a4f90ad54f4fae363e8d3cc41d539557b">complex.h</a></li>
|
||||
<li>ceildiv() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a8e5a4b0fb5d018d7b078d147efe4f1e3">utils.h</a></li>
|
||||
<li>col_reduce_2pass() : <a class="el" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">reduce_col.h</a></li>
|
||||
<li>col_reduce_longcolumn() : <a class="el" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">reduce_col.h</a></li>
|
||||
<li>col_reduce_looped() : <a class="el" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">reduce_col.h</a></li>
|
||||
<li>col_reduce_small() : <a class="el" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">reduce_col.h</a></li>
|
||||
<li>col_reduce_small() : <a class="el" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">reduce_col.h</a></li>
|
||||
<li>complex_binop : <a class="el" href="types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd">complex.h</a></li>
|
||||
<li>complex_binop_helper : <a class="el" href="types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4">complex.h</a></li>
|
||||
<li>complex_mul() : <a class="el" href="radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6">radix.h</a></li>
|
||||
|
4
docs/build/html/globals_func_c.html
vendored
4
docs/build/html/globals_func_c.html
vendored
@ -88,8 +88,10 @@ $(function(){ initResizable(false); });
|
||||
|
||||
<h3><a id="index_c" name="index_c"></a>- c -</h3><ul>
|
||||
<li>ceildiv() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a8e5a4b0fb5d018d7b078d147efe4f1e3">utils.h</a></li>
|
||||
<li>col_reduce_2pass() : <a class="el" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">reduce_col.h</a></li>
|
||||
<li>col_reduce_longcolumn() : <a class="el" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">reduce_col.h</a></li>
|
||||
<li>col_reduce_looped() : <a class="el" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">reduce_col.h</a></li>
|
||||
<li>col_reduce_small() : <a class="el" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">reduce_col.h</a></li>
|
||||
<li>col_reduce_small() : <a class="el" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">reduce_col.h</a></li>
|
||||
<li>complex_mul() : <a class="el" href="radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6">radix.h</a></li>
|
||||
<li>complex_mul_conj() : <a class="el" href="radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3">radix.h</a></li>
|
||||
<li>contiguous_scan() : <a class="el" href="scan_8h.html#a60d279b9add7d56639bb209408f09d79">scan.h</a></li>
|
||||
|
3
docs/build/html/globals_func_q.html
vendored
3
docs/build/html/globals_func_q.html
vendored
@ -101,7 +101,8 @@ $(function(){ initResizable(false); });
|
||||
<li>qmv_quad_impl() : <a class="el" href="quantized_8h.html#ad5cf1cf63656bc1780685d22169cd4ef">quantized.h</a></li>
|
||||
<li>qouter() : <a class="el" href="quantized_8h.html#ae756f6817b584c60f5dcdd1d9c6b4f58">quantized.h</a></li>
|
||||
<li>qvm() : <a class="el" href="quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">quantized.h</a></li>
|
||||
<li>qvm_impl() : <a class="el" href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12">quantized.h</a></li>
|
||||
<li>qvm_impl() : <a class="el" href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a">quantized.h</a></li>
|
||||
<li>qvm_split_k() : <a class="el" href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8">quantized.h</a></li>
|
||||
</ul>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
|
2
docs/build/html/globals_func_s.html
vendored
2
docs/build/html/globals_func_s.html
vendored
@ -88,7 +88,7 @@ $(function(){ initResizable(false); });
|
||||
|
||||
<h3><a id="index_s" name="index_s"></a>- s -</h3><ul>
|
||||
<li>scatter_impl() : <a class="el" href="scatter_8h.html#ad1ce39d0b6d733a95e739121fcc61bd1">scatter.h</a></li>
|
||||
<li>sdpa_vector() : <a class="el" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector.h</a></li>
|
||||
<li>sdpa_vector() : <a class="el" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector.h</a></li>
|
||||
<li>simd_shuffle() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a71986ecdd7d18f975dd22c3df7421ce2">utils.h</a></li>
|
||||
<li>simd_shuffle_and_fill_up() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a5862d5ea154c9b76cf56a630cf6385b4">utils.h</a></li>
|
||||
<li>simd_shuffle_down() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#aba6279624b1d30c525efee856a222b5c">utils.h</a></li>
|
||||
|
3
docs/build/html/globals_q.html
vendored
3
docs/build/html/globals_q.html
vendored
@ -102,7 +102,8 @@ $(function(){ initResizable(false); });
|
||||
<li>qouter() : <a class="el" href="quantized_8h.html#ae756f6817b584c60f5dcdd1d9c6b4f58">quantized.h</a></li>
|
||||
<li>QUAD_SIZE : <a class="el" href="quantized_8h.html#a803e4d5a1459844ba647aea5b004e133">quantized.h</a></li>
|
||||
<li>qvm() : <a class="el" href="quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">quantized.h</a></li>
|
||||
<li>qvm_impl() : <a class="el" href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12">quantized.h</a></li>
|
||||
<li>qvm_impl() : <a class="el" href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a">quantized.h</a></li>
|
||||
<li>qvm_split_k() : <a class="el" href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8">quantized.h</a></li>
|
||||
</ul>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
|
2
docs/build/html/globals_s.html
vendored
2
docs/build/html/globals_s.html
vendored
@ -89,7 +89,7 @@ $(function(){ initResizable(false); });
|
||||
<h3><a id="index_s" name="index_s"></a>- s -</h3><ul>
|
||||
<li>scatter_impl() : <a class="el" href="scatter_8h.html#ad1ce39d0b6d733a95e739121fcc61bd1">scatter.h</a></li>
|
||||
<li>scatter_kernels : <a class="el" href="jit_2indexing_8h.html#a768c949cd650a44c6b402fc1440c1a56">indexing.h</a></li>
|
||||
<li>sdpa_vector() : <a class="el" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector.h</a></li>
|
||||
<li>sdpa_vector() : <a class="el" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector.h</a></li>
|
||||
<li>simd_shuffle() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a71986ecdd7d18f975dd22c3df7421ce2">utils.h</a></li>
|
||||
<li>simd_shuffle_and_fill_up() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#a5862d5ea154c9b76cf56a630cf6385b4">utils.h</a></li>
|
||||
<li>simd_shuffle_down() : <a class="el" href="backend_2metal_2kernels_2utils_8h.html#aba6279624b1d30c525efee856a222b5c">utils.h</a></li>
|
||||
|
4
docs/build/html/kernels_8h.html
vendored
4
docs/build/html/kernels_8h.html
vendored
@ -129,8 +129,8 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:a84ebe6275218070f0ea320f126f64e22"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:afb57825bb763050cc9a9d194aa41ac36" id="r_afb57825bb763050cc9a9d194aa41ac36"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#afb57825bb763050cc9a9d194aa41ac36">mlx::core::get_mb_sort_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &idx, int bn, int tn)</td></tr>
|
||||
<tr class="separator:afb57825bb763050cc9a9d194aa41ac36"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a51c4bb09230348bd0252e22bfdc9bc89" id="r_a51c4bb09230348bd0252e22bfdc9bc89"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core::get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a51c4bb09230348bd0252e22bfdc9bc89"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3bd386cb6db09f636963ce66ceaf8647" id="r_a3bd386cb6db09f636963ce66ceaf8647"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core::get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a3bd386cb6db09f636963ce66ceaf8647"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a7aa91fcfe8b9caa42d60a957f11bfe6b" id="r_a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core::get_reduce_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int ndim=-1, int bm=-1, int bn=-1)</td></tr>
|
||||
<tr class="separator:a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a84fa8e0aee321a9d614433a0b933103b" id="r_a84fa8e0aee321a9d614433a0b933103b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">mlx::core::get_steel_gemm_fused_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &func_consts, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</td></tr>
|
||||
|
286
docs/build/html/kernels_8h_source.html
vendored
286
docs/build/html/kernels_8h_source.html
vendored
@ -169,152 +169,154 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="keywordtype">int</span> tn);</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> </div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89"> 79</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">get_reduce_init_kernel</a>(</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647"> 79</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">get_reduce_init_kernel</a>(</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out);</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> </div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"> 84</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a>(</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keyword">const</span> std::string& func_name,</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordtype">int</span> ndim = -1,</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordtype">int</span> bm = -1,</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> <span class="keywordtype">int</span> bn = -1);</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> </div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b"> 95</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a>(</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keywordtype">int</span> wn);</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> </div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5"> 109</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5">get_steel_gemm_splitk_kernel</a>(</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordtype">bool</span> mn_aligned,</div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="keywordtype">bool</span> k_aligned);</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> </div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b"> 124</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">get_steel_gemm_splitk_accum_kernel</a>(</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keywordtype">bool</span> axbpy);</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> </div>
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d"> 131</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">get_steel_gemm_masked_kernel</a>(</div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keyword">const</span> std::optional<array>& mask_out,</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keyword">const</span> std::optional<array>& mask_op,</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordtype">bool</span> mn_aligned,</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">bool</span> k_aligned);</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> </div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4"> 147</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">get_steel_conv_kernel</a>(</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"> 149</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keywordtype">int</span> n_channel_specialization,</div>
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> <span class="keywordtype">bool</span> small_filter);</div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> </div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818"> 159</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">get_gemv_masked_kernel</a>(</div>
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> std::optional<array>& mask_out,</div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> std::optional<array>& mask_op,</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keywordtype">bool</span> transpose_mat,</div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordtype">int</span> sm,</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> <span class="keywordtype">int</span> sn,</div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keywordtype">int</span> tm,</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keywordtype">int</span> tn,</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keywordtype">bool</span> contiguous);</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> </div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d"> 174</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">get_steel_conv_general_kernel</a>(</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordtype">int</span> wn);</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> </div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f"> 184</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">get_fft_kernel</a>(</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keyword">const</span> std::string& template_def);</div>
|
||||
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
|
||||
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e"> 191</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">get_quantized_kernel</a>(</div>
|
||||
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <span class="keyword">const</span> std::string& template_def);</div>
|
||||
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> </div>
|
||||
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span><span class="comment">// Create a GPU kernel template definition for JIT compilation</span></div>
|
||||
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span><span class="keyword">template</span> <<span class="keyword">typename</span>... Args></div>
|
||||
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span>std::string</div>
|
||||
<div class="foldopen" id="foldopen00199" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032"> 199</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">get_template_definition</a>(std::string name, std::string func, Args... args) {</div>
|
||||
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> std::ostringstream s;</div>
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> s << func << <span class="stringliteral">"<"</span>;</div>
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> <span class="keywordtype">bool</span> first = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">auto</span> add_arg = [&s, &first](<span class="keyword">const</span> <span class="keyword">auto</span>& arg) {</div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keywordflow">if</span> (!first) {</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> s << <span class="stringliteral">", "</span>;</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> }</div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> first = <span class="keyword">false</span>;</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> s << arg;</div>
|
||||
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> };</div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> (add_arg(args), ...);</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> s << <span class="stringliteral">">"</span>;</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keywordflow">return</span> fmt::format(</div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> <span class="stringliteral">"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n"</span>,</div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> name,</div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> s.str());</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span>}</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keyword">const</span> std::string& func_name,</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out);</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> </div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"> 86</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a>(</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keyword">const</span> std::string& func_name,</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> <span class="keywordtype">int</span> ndim = -1,</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> <span class="keywordtype">int</span> bm = -1,</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> <span class="keywordtype">int</span> bn = -1);</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> </div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b"> 97</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a>(</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordtype">int</span> wn);</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> </div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5"> 111</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5">get_steel_gemm_splitk_kernel</a>(</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> <span class="keywordtype">bool</span> mn_aligned,</div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="keywordtype">bool</span> k_aligned);</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> </div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b"> 126</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">get_steel_gemm_splitk_accum_kernel</a>(</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> <span class="keywordtype">bool</span> axbpy);</div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> </div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d"> 133</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">get_steel_gemm_masked_kernel</a>(</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> <span class="keyword">const</span> std::optional<array>& mask_out,</div>
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="keyword">const</span> std::optional<array>& mask_op,</div>
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">bool</span> transpose_a,</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordtype">bool</span> transpose_b,</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> <span class="keywordtype">bool</span> mn_aligned,</div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"> 147</span> <span class="keywordtype">bool</span> k_aligned);</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> </div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4"> 149</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">get_steel_conv_kernel</a>(</div>
|
||||
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> <span class="keywordtype">int</span> wn,</div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keywordtype">int</span> n_channel_specialization,</div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="keywordtype">bool</span> small_filter);</div>
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> </div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818"> 161</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">get_gemv_masked_kernel</a>(</div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> std::optional<array>& mask_out,</div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keyword">const</span> std::optional<array>& mask_op,</div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keywordtype">bool</span> transpose_mat,</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keywordtype">int</span> sm,</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keywordtype">int</span> sn,</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keywordtype">int</span> tm,</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keywordtype">int</span> tn,</div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keywordtype">bool</span> contiguous);</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> </div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d"> 176</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">get_steel_conv_general_kernel</a>(</div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keywordtype">int</span> bm,</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keywordtype">int</span> bn,</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordtype">int</span> bk,</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> <span class="keywordtype">int</span> wm,</div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> <span class="keywordtype">int</span> wn);</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> </div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f"> 186</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">get_fft_kernel</a>(</div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keyword">const</span> std::string& hash_name,</div>
|
||||
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>& func_consts,</div>
|
||||
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <span class="keyword">const</span> std::string& template_def);</div>
|
||||
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> </div>
|
||||
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e"> 193</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">get_quantized_kernel</a>(</div>
|
||||
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keyword">const</span> std::string& kernel_name,</div>
|
||||
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> <span class="keyword">const</span> std::string& template_def);</div>
|
||||
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> </div>
|
||||
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span><span class="comment">// Create a GPU kernel template definition for JIT compilation</span></div>
|
||||
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span><span class="keyword">template</span> <<span class="keyword">typename</span>... Args></div>
|
||||
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span>std::string</div>
|
||||
<div class="foldopen" id="foldopen00201" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032"> 201</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">get_template_definition</a>(std::string name, std::string func, Args... args) {</div>
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> std::ostringstream s;</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> s << func << <span class="stringliteral">"<"</span>;</div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keywordtype">bool</span> first = <span class="keyword">true</span>;</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">auto</span> add_arg = [&s, &first](<span class="keyword">const</span> <span class="keyword">auto</span>& arg) {</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keywordflow">if</span> (!first) {</div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> s << <span class="stringliteral">", "</span>;</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> }</div>
|
||||
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> first = <span class="keyword">false</span>;</div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> s << arg;</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> };</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> (add_arg(args), ...);</div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> s << <span class="stringliteral">">"</span>;</div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> <span class="keywordflow">return</span> fmt::format(</div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="stringliteral">"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n"</span>,</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> name,</div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> s.str());</div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> </div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> </div>
|
||||
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="ttc" id="aarray_8h_html"><div class="ttname"><a href="array_8h.html">array.h</a></div></div>
|
||||
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
|
||||
<div class="ttc" id="acommon_2binary_8h_html_a70228731d29946574b238d21fb4b360c"><div class="ttname"><a href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a></div><div class="ttdeci">Op op</div><div class="ttdef"><b>Definition</b> binary.h:129</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_1_1metal_html_a616e09a1ef321d527770721cef264c54"><div class="ttname"><a href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">mlx::core::metal::MTLFCList</a></div><div class="ttdeci">std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList</div><div class="ttdef"><b>Definition</b> device.h:38</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
|
||||
@ -322,9 +324,9 @@ $(function(){ initResizable(false); });
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a195b86cad5bb99aa1bcd23952305af6b"><div class="ttname"><a href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">mlx::core::get_steel_gemm_splitk_accum_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a1d4cffc3c78067b3d9a62d64f3fb686f"><div class="ttname"><a href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">mlx::core::get_fft_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a35a412f688d79eb47e42d20a7c8650ee"><div class="ttname"><a href="namespacemlx_1_1core.html#a35a412f688d79eb47e42d20a7c8650ee">mlx::core::get_softmax_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a3bd386cb6db09f636963ce66ceaf8647"><div class="ttname"><a href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core::get_reduce_init_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &out)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a4decd4a07d91487e6903f6e3c8b7513a"><div class="ttname"><a href="namespacemlx_1_1core.html#a4decd4a07d91487e6903f6e3c8b7513a">mlx::core::get_binary_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a4e809746f48e5dcf7fa63215d3f5e33e"><div class="ttname"><a href="namespacemlx_1_1core.html#a4e809746f48e5dcf7fa63215d3f5e33e">mlx::core::get_binary_two_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a51c4bb09230348bd0252e22bfdc9bc89"><div class="ttname"><a href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core::get_reduce_init_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a54eb3b65375022428aab5f810e40624b"><div class="ttname"><a href="namespacemlx_1_1core.html#a54eb3b65375022428aab5f810e40624b">mlx::core::get_ternary_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a76f614e9956a6ca05a9be4db5a483446"><div class="ttname"><a href="namespacemlx_1_1core.html#a76f614e9956a6ca05a9be4db5a483446">mlx::core::get_arange_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a7aa91fcfe8b9caa42d60a957f11bfe6b"><div class="ttname"><a href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core::get_reduce_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)</div></div>
|
||||
@ -332,7 +334,7 @@ $(function(){ initResizable(false); });
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a84fa8e0aee321a9d614433a0b933103b"><div class="ttname"><a href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">mlx::core::get_steel_gemm_fused_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a90c24e0d0b99b68fad9deefcf4d3e818"><div class="ttname"><a href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">mlx::core::get_gemv_masked_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_aa3faeae5378bfaafe3ce3432a051e43e"><div class="ttname"><a href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core::get_quantized_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_aae0d19f0acdef2accd2428fb84c8a032"><div class="ttname"><a href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">mlx::core::get_template_definition</a></div><div class="ttdeci">std::string get_template_definition(std::string name, std::string func, Args... args)</div><div class="ttdef"><b>Definition</b> kernels.h:199</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_aae0d19f0acdef2accd2428fb84c8a032"><div class="ttname"><a href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">mlx::core::get_template_definition</a></div><div class="ttdeci">std::string get_template_definition(std::string name, std::string func, Args... args)</div><div class="ttdef"><b>Definition</b> kernels.h:201</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_ab5f60614e965144b451930fdf935e08d"><div class="ttname"><a href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">mlx::core::get_steel_gemm_masked_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_abce2b67044ee06a7bbe7a91ec7c8c48d"><div class="ttname"><a href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">mlx::core::get_steel_conv_general_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_adce79d220672f5f3c65cc31d145ca9c4"><div class="ttname"><a href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">mlx::core::get_steel_conv_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)</div></div>
|
||||
|
2
docs/build/html/matmul_8h_source.html
vendored
2
docs/build/html/matmul_8h_source.html
vendored
@ -143,7 +143,7 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a227588758ccc9ee869dba147e830bb74"><div class="ttname"><a href="namespacemlx_1_1core.html#a227588758ccc9ee869dba147e830bb74">mlx::core::steel_matmul_regular</a></div><div class="ttdeci">void steel_matmul_regular(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector< int > batch_shape, std::vector< size_t > batch_strides, size_t A_batch_stride, size_t B_batch_stride, size_t matrix_stride_out, std::vector< array > &copies)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_ab43a7633794498e1c6775cca829eb886"><div class="ttname"><a href="namespacemlx_1_1core.html#ab43a7633794498e1c6775cca829eb886">mlx::core::steel_matmul</a></div><div class="ttdeci">void steel_matmul(const Stream &s, metal::Device &d, const array &a, const array &b, array &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< array > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})</div></div>
|
||||
|
4
docs/build/html/metal_2reduce_8h.html
vendored
4
docs/build/html/metal_2reduce_8h.html
vendored
@ -109,8 +109,8 @@ Namespaces</h2></td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:af7b7ca7c6aa87558d9f98cee5c7a99a8" id="r_af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core::all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s, std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &copies)</td></tr>
|
||||
<tr class="separator:af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3ab0fd997d9a35782106ff083a72e098" id="r_a3ab0fd997d9a35782106ff083a72e098"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core::all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
<tr class="separator:a3ab0fd997d9a35782106ff083a72e098"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab1eeca8ec6fa31819ee108fa6ed2c41b" id="r_ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">mlx::core::row_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &plan, const std::vector< int > &axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
<tr class="separator:ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa0332c64ee9965f05026c30a0b778000" id="r_aa0332c64ee9965f05026c30a0b778000"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">mlx::core::strided_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &plan, const std::vector< int > &axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
|
53
docs/build/html/metal_2reduce_8h_source.html
vendored
53
docs/build/html/metal_2reduce_8h_source.html
vendored
@ -103,44 +103,43 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> </div>
|
||||
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span><span class="keyword">using </span>metal::CommandEncoder;</div>
|
||||
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> </div>
|
||||
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"> 13</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">all_reduce_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098"> 13</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">all_reduce_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& compute_encoder,</div>
|
||||
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s,</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> std::vector<array>& copies);</div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> </div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"> 22</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>& plan,</div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <span class="keyword">const</span> std::vector<int>& axes,</div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& compute_encoder,</div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s);</div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> </div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000"> 32</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>& plan,</div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keyword">const</span> std::vector<int>& axes,</div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& compute_encoder,</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s);</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> </div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s);</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> </div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"> 21</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>& plan,</div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">const</span> std::vector<int>& axes,</div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& compute_encoder,</div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s);</div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> </div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000"> 31</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a>(</div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& in,</div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>& out,</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keyword">const</span> std::string& op_name,</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>& plan,</div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keyword">const</span> std::vector<int>& axes,</div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>& compute_encoder,</div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>& d,</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>& s);</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> </div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span>} <span class="comment">// namespace mlx::core</span></div>
|
||||
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
|
||||
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
|
||||
<div class="ttc" id="acommon_2reduce_8h_html"><div class="ttname"><a href="common_2reduce_8h.html">reduce.h</a></div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_a3ab0fd997d9a35782106ff083a72e098"><div class="ttname"><a href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core::all_reduce_dispatch</a></div><div class="ttdeci">void all_reduce_dispatch(const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_aa0332c64ee9965f05026c30a0b778000"><div class="ttname"><a href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">mlx::core::strided_reduce_general_dispatch</a></div><div class="ttdeci">void strided_reduce_general_dispatch(const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_ab1eeca8ec6fa31819ee108fa6ed2c41b"><div class="ttname"><a href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">mlx::core::row_reduce_general_dispatch</a></div><div class="ttdeci">void row_reduce_general_dispatch(const array &in, array &out, const std::string &op_name, const ReductionPlan &plan, const std::vector< int > &axes, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s)</div></div>
|
||||
<div class="ttc" id="anamespacemlx_1_1core_html_af7b7ca7c6aa87558d9f98cee5c7a99a8"><div class="ttname"><a href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core::all_reduce_dispatch</a></div><div class="ttdeci">void all_reduce_dispatch(const array &in, array &out, const std::string &op_name, CommandEncoder &compute_encoder, metal::Device &d, const Stream &s, std::vector< array > &copies)</div></div>
|
||||
<div class="ttc" id="astream_8h_html"><div class="ttname"><a href="stream_8h.html">stream.h</a></div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1_reduction_plan_html"><div class="ttname"><a href="structmlx_1_1core_1_1_reduction_plan.html">mlx::core::ReductionPlan</a></div><div class="ttdef"><b>Definition</b> reduce.h:39</div></div>
|
||||
<div class="ttc" id="astructmlx_1_1core_1_1_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1_stream.html">mlx::core::Stream</a></div><div class="ttdef"><b>Definition</b> stream.h:9</div></div>
|
||||
|
2
docs/build/html/namespacemembers.html
vendored
2
docs/build/html/namespacemembers.html
vendored
@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
|
||||
<li>aligned_dealloc() : <a class="el" href="namespacepocketfft_1_1detail.html#aec7820e36a33e0a8bb83aa03b04b81e8">pocketfft::detail</a></li>
|
||||
<li>all() : <a class="el" href="group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68">mlx::core</a></li>
|
||||
<li>all_gather() : <a class="el" href="namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04">mlx::core::distributed::detail</a></li>
|
||||
<li>all_reduce_dispatch() : <a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core</a></li>
|
||||
<li>all_reduce_dispatch() : <a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core</a></li>
|
||||
<li>all_sum() : <a class="el" href="namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960">mlx::core::distributed::detail</a></li>
|
||||
<li>allclose() : <a class="el" href="group__ops.html#gaf0cd4257de7542daf9faf5e605e31020">mlx::core</a></li>
|
||||
<li>alloc_tmp() : <a class="el" href="namespacepocketfft_1_1detail.html#a4db03cbcd9d43d9e0b0b9067713c80e9">pocketfft::detail</a></li>
|
||||
|
2
docs/build/html/namespacemembers_func.html
vendored
2
docs/build/html/namespacemembers_func.html
vendored
@ -98,7 +98,7 @@ $(function(){ initResizable(false); });
|
||||
<li>aligned_dealloc() : <a class="el" href="namespacepocketfft_1_1detail.html#aec7820e36a33e0a8bb83aa03b04b81e8">pocketfft::detail</a></li>
|
||||
<li>all() : <a class="el" href="group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68">mlx::core</a></li>
|
||||
<li>all_gather() : <a class="el" href="namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04">mlx::core::distributed::detail</a></li>
|
||||
<li>all_reduce_dispatch() : <a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core</a></li>
|
||||
<li>all_reduce_dispatch() : <a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core</a></li>
|
||||
<li>all_sum() : <a class="el" href="namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960">mlx::core::distributed::detail</a></li>
|
||||
<li>allclose() : <a class="el" href="group__ops.html#gaf0cd4257de7542daf9faf5e605e31020">mlx::core</a></li>
|
||||
<li>alloc_tmp() : <a class="el" href="namespacepocketfft_1_1detail.html#a4db03cbcd9d43d9e0b0b9067713c80e9">pocketfft::detail</a></li>
|
||||
|
2
docs/build/html/namespacemembers_func_g.html
vendored
2
docs/build/html/namespacemembers_func_g.html
vendored
@ -112,7 +112,7 @@ $(function(){ initResizable(false); });
|
||||
<li>get_pool() : <a class="el" href="namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a">pocketfft::detail::threading</a></li>
|
||||
<li>get_primitive_string() : <a class="el" href="namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60">mlx::core</a></li>
|
||||
<li>get_quantized_kernel() : <a class="el" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core</a></li>
|
||||
<li>get_reduce_init_kernel() : <a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core</a></li>
|
||||
<li>get_reduce_init_kernel() : <a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core</a></li>
|
||||
<li>get_reduce_kernel() : <a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core</a></li>
|
||||
<li>get_reduction_plan() : <a class="el" href="namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba">mlx::core</a></li>
|
||||
<li>get_scan_kernel() : <a class="el" href="namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f">mlx::core</a></li>
|
||||
|
2
docs/build/html/namespacemembers_g.html
vendored
2
docs/build/html/namespacemembers_g.html
vendored
@ -116,7 +116,7 @@ $(function(){ initResizable(false); });
|
||||
<li>get_pool() : <a class="el" href="namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a">pocketfft::detail::threading</a></li>
|
||||
<li>get_primitive_string() : <a class="el" href="namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60">mlx::core</a></li>
|
||||
<li>get_quantized_kernel() : <a class="el" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core</a></li>
|
||||
<li>get_reduce_init_kernel() : <a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core</a></li>
|
||||
<li>get_reduce_init_kernel() : <a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core</a></li>
|
||||
<li>get_reduce_kernel() : <a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core</a></li>
|
||||
<li>get_reduction_plan() : <a class="el" href="namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba">mlx::core</a></li>
|
||||
<li>get_scan_kernel() : <a class="el" href="namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f">mlx::core</a></li>
|
||||
|
33
docs/build/html/namespacemlx_1_1core.html
vendored
33
docs/build/html/namespacemlx_1_1core.html
vendored
@ -534,8 +534,8 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:a84ebe6275218070f0ea320f126f64e22"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:afb57825bb763050cc9a9d194aa41ac36" id="r_afb57825bb763050cc9a9d194aa41ac36"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#afb57825bb763050cc9a9d194aa41ac36">get_mb_sort_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &idx, int bn, int tn)</td></tr>
|
||||
<tr class="separator:afb57825bb763050cc9a9d194aa41ac36"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a51c4bb09230348bd0252e22bfdc9bc89" id="r_a51c4bb09230348bd0252e22bfdc9bc89"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a51c4bb09230348bd0252e22bfdc9bc89">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a51c4bb09230348bd0252e22bfdc9bc89"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3bd386cb6db09f636963ce66ceaf8647" id="r_a3bd386cb6db09f636963ce66ceaf8647"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a3bd386cb6db09f636963ce66ceaf8647">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out)</td></tr>
|
||||
<tr class="separator:a3bd386cb6db09f636963ce66ceaf8647"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a7aa91fcfe8b9caa42d60a957f11bfe6b" id="r_a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int ndim=-1, int bm=-1, int bn=-1)</td></tr>
|
||||
<tr class="separator:a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a84fa8e0aee321a9d614433a0b933103b" id="r_a84fa8e0aee321a9d614433a0b933103b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState * </td><td class="memItemRight" valign="bottom"><a class="el" href="#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const std::string &kernel_name, const std::string &hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &func_consts, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</td></tr>
|
||||
@ -563,8 +563,8 @@ Functions</h2></td></tr>
|
||||
<tr class="separator:a227588758ccc9ee869dba147e830bb74"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab43a7633794498e1c6775cca829eb886" id="r_ab43a7633794498e1c6775cca829eb886"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#ab43a7633794498e1c6775cca829eb886">steel_matmul</a> (const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &a, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &b, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &copies, std::vector< int > batch_shape={}, std::vector< size_t > A_batch_stride={}, std::vector< size_t > B_batch_stride={})</td></tr>
|
||||
<tr class="separator:ab43a7633794498e1c6775cca829eb886"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:af7b7ca7c6aa87558d9f98cee5c7a99a8" id="r_af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s, std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &copies)</td></tr>
|
||||
<tr class="separator:af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3ab0fd997d9a35782106ff083a72e098" id="r_a3ab0fd997d9a35782106ff083a72e098"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#a3ab0fd997d9a35782106ff083a72e098">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
<tr class="separator:a3ab0fd997d9a35782106ff083a72e098"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab1eeca8ec6fa31819ee108fa6ed2c41b" id="r_ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &plan, const std::vector< int > &axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
<tr class="separator:ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa0332c64ee9965f05026c30a0b778000" id="r_aa0332c64ee9965f05026c30a0b778000"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &out, const std::string &op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &plan, const std::vector< int > &axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &s)</td></tr>
|
||||
@ -2634,8 +2634,8 @@ template<typename... T> </div>
|
||||
</div>
|
||||
</div>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="af7b7ca7c6aa87558d9f98cee5c7a99a8" name="af7b7ca7c6aa87558d9f98cee5c7a99a8"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">◆ </a></span>all_reduce_dispatch()</h2>
|
||||
<a id="a3ab0fd997d9a35782106ff083a72e098" name="a3ab0fd997d9a35782106ff083a72e098"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3ab0fd997d9a35782106ff083a72e098">◆ </a></span>all_reduce_dispatch()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@ -2668,12 +2668,7 @@ template<typename... T> </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &</td> <td class="paramname"><span class="paramname"><em>s</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">std::vector< <a class="el" href="classmlx_1_1core_1_1array.html">array</a> > &</td> <td class="paramname"><span class="paramname"><em>copies</em></span> )</td>
|
||||
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &</td> <td class="paramname"><span class="paramname"><em>s</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
@ -4418,8 +4413,8 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a51c4bb09230348bd0252e22bfdc9bc89" name="a51c4bb09230348bd0252e22bfdc9bc89"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a51c4bb09230348bd0252e22bfdc9bc89">◆ </a></span>get_reduce_init_kernel()</h2>
|
||||
<a id="a3bd386cb6db09f636963ce66ceaf8647" name="a3bd386cb6db09f636963ce66ceaf8647"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3bd386cb6db09f636963ce66ceaf8647">◆ </a></span>get_reduce_init_kernel()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@ -4434,6 +4429,16 @@ template<typename... Arrays, typename = enable_for_arrays_t<Arrays...>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>kernel_name</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>func_name</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const std::string &</td> <td class="paramname"><span class="paramname"><em>op_name</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
|
BIN
docs/build/html/objects.inv
vendored
BIN
docs/build/html/objects.inv
vendored
Binary file not shown.
@ -867,6 +867,7 @@
|
||||
<dt class="sig sig-object py" id="mlx.core.fast.metal_kernel">
|
||||
<span class="sig-name descname"><span class="pre">metal_kernel</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.13)"><span class="pre">Sequence</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">output_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.13)"><span class="pre">Sequence</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">source</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">header</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ensure_row_contiguous</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">atomic_outputs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">→</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/functions.html#object" title="(in Python v3.13)"><span class="pre">object</span></a></span></span><a class="headerlink" href="#mlx.core.fast.metal_kernel" title="Link to this definition">#</a></dt>
|
||||
<dd><p>A jit-compiled custom Metal kernel defined from a source string.</p>
|
||||
<p>Full documentation: <a class="reference internal" href="../../dev/custom_metal_kernels.html#custom-metal-kernels"><span class="std std-ref">Custom Metal Kernels</span></a>.</p>
|
||||
<dl class="field-list simple">
|
||||
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
|
||||
<dd class="field-odd"><ul class="simple">
|
||||
|
121
docs/build/html/quantized_8h.html
vendored
121
docs/build/html/quantized_8h.html
vendored
@ -140,9 +140,9 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd" id="r_a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplParams" colspan="2">template<typename T , int group_size, int bits> </td></tr>
|
||||
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a8e13c7d895624f738d2a6d9893b687fd">qmv_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a8e13c7d895624f738d2a6d9893b687fd"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12" id="r_a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits> </td></tr>
|
||||
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4a8c8db7d5d480733726fd6d1a645e12">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a" id="r_a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits> </td></tr>
|
||||
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1546533c5b925b2fbb3bec870ec7487a">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a1546533c5b925b2fbb3bec870ec7487a"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8" id="r_af5750a35e8f5462218effba719f7f5b8"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> </td></tr>
|
||||
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#af5750a35e8f5462218effba719f7f5b8">qmm_t_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:af5750a35e8f5462218effba719f7f5b8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@ -167,6 +167,9 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5" id="r_ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, bool batched> </td></tr>
|
||||
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">qvm</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8" id="r_ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, int split_k = 32> </td></tr>
|
||||
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ab8243818512d6078d23e6ffb65fd7bb8">qvm_split_k</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10" id="r_abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> </td></tr>
|
||||
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#abe2e3ef0ee4ec2cb61dc5330ad463d10">qmm_t</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@ -2485,8 +2488,8 @@ template<typename T , const int group_size, const int bits, bool batched>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a4a8c8db7d5d480733726fd6d1a645e12" name="a4a8c8db7d5d480733726fd6d1a645e12"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4a8c8db7d5d480733726fd6d1a645e12">◆ </a></span>qvm_impl()</h2>
|
||||
<a id="a1546533c5b925b2fbb3bec870ec7487a" name="a1546533c5b925b2fbb3bec870ec7487a"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1546533c5b925b2fbb3bec870ec7487a">◆ </a></span>qvm_impl()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@ -2518,6 +2521,69 @@ template<typename T , const int group_size, const int bits> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>in_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_gid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lid</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ab8243818512d6078d23e6ffb65fd7bb8" name="ab8243818512d6078d23e6ffb65fd7bb8"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ab8243818512d6078d23e6ffb65fd7bb8">◆ </a></span>qvm_split_k()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , const int group_size, const int bits, int split_k = 32> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void qvm_split_k </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device uint32_t *</td> <td class="paramname"><span class="paramname"><em>w</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>scales</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>biases</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>x</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@ -2528,6 +2594,51 @@ template<typename T , const int group_size, const int bits> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>x_batch_ndims</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>x_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>x_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>w_batch_ndims</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>w_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>w_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>s_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>b_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>final_block_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
|
1274
docs/build/html/quantized_8h_source.html
vendored
1274
docs/build/html/quantized_8h_source.html
vendored
File diff suppressed because it is too large
Load Diff
218
docs/build/html/reduce__col_8h.html
vendored
218
docs/build/html/reduce__col_8h.html
vendored
@ -98,15 +98,207 @@ $(function(){ initResizable(false); });
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce" id="r_adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS> </td></tr>
|
||||
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)</td></tr>
|
||||
<tr class="separator:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a7c378443a2b6f4d9210db8a21a9ac4f5" id="r_a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS> </td></tr>
|
||||
<tr class="memitem:a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</td></tr>
|
||||
<tr class="separator:a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a5b4f4c4c247ad341ff8d31dcbbbce0eb" id="r_a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS> </td></tr>
|
||||
<tr class="memitem:a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</td></tr>
|
||||
<tr class="separator:a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385" id="r_a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS, int BM, int BN> </td></tr>
|
||||
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</td></tr>
|
||||
<tr class="memdesc:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="mdescLeft"> </td><td class="mdescRight">Our approach is the following simple looped approach: <br /></td></tr>
|
||||
<tr class="separator:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d" id="r_a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS, int BM, int BN> </td></tr>
|
||||
<tr class="memitem:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</td></tr>
|
||||
<tr class="separator:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="a0e92fc74eeaa8ee2ceb83bafc6eb1d7d" name="a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">◆ </a></span>col_reduce_2pass()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int NDIMS, int BM, int BN> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_2pass </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">device U *</td> <td class="paramname"><span class="paramname"><em>out</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>ndim</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>non_col_reductions</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>out_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a5b4f4c4c247ad341ff8d31dcbbbce0eb" name="a5b4f4c4c247ad341ff8d31dcbbbce0eb"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a5b4f4c4c247ad341ff8d31dcbbbce0eb">◆ </a></span>col_reduce_longcolumn()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int NDIMS> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_longcolumn </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">device U *</td> <td class="paramname"><span class="paramname"><em>out</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>ndim</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>non_col_reductions</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>out_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a11bfc6112ae2386ac03f5ea7b7d93385" name="a11bfc6112ae2386ac03f5ea7b7d93385"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a11bfc6112ae2386ac03f5ea7b7d93385">◆ </a></span>col_reduce_looped()</h2>
|
||||
|
||||
@ -204,13 +396,13 @@ template<typename T , typename U , typename Op , int NDIMS, int BM, int BN>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="adf7aeb18cd1d5042cf6d9b46b582d8ce" name="adf7aeb18cd1d5042cf6d9b46b582d8ce"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">◆ </a></span>col_reduce_small()</h2>
|
||||
<a id="a7c378443a2b6f4d9210db8a21a9ac4f5" name="a7c378443a2b6f4d9210db8a21a9ac4f5"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a7c378443a2b6f4d9210db8a21a9ac4f5">◆ </a></span>col_reduce_small()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS> </div>
|
||||
template<typename T , typename U , typename Op , int NDIMS> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_small </td>
|
||||
@ -280,22 +472,12 @@ template<typename T , typename U , typename Op , int NDIMS, int N_READS = RED
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em></span>, </td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tsize</em></span> )</td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
692
docs/build/html/reduce__col_8h_source.html
vendored
692
docs/build/html/reduce__col_8h_source.html
vendored
@ -93,334 +93,392 @@ $(function(){ initResizable(false); });
|
||||
<div class="contents">
|
||||
<a href="reduce__col_8h.html">Go to the documentation of this file.</a><div class="fragment"><div class="line"><a id="l00001" name="l00001"></a><span class="lineno"> 1</span><span class="comment">// Copyright © 2023-2024 Apple Inc.</span></div>
|
||||
<div class="line"><a id="l00002" name="l00002"></a><span class="lineno"> 2</span> </div>
|
||||
<div class="line"><a id="l00003" name="l00003"></a><span class="lineno"> 3</span><span class="keyword">template</span> <</div>
|
||||
<div class="line"><a id="l00004" name="l00004"></a><span class="lineno"> 4</span> <span class="keyword">typename</span> T,</div>
|
||||
<div class="line"><a id="l00005" name="l00005"></a><span class="lineno"> 5</span> <span class="keyword">typename</span> U,</div>
|
||||
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> <span class="keyword">typename</span> Op,</div>
|
||||
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span> <span class="keywordtype">int</span> NDIMS,</div>
|
||||
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"> 8</span> <span class="keywordtype">int</span> N_READS = <a class="code hl_variable" href="defines_8h.html#a2ad505864a2ab786147766900bc18c21">REDUCE_N_READS</a>></div>
|
||||
<div class="foldopen" id="foldopen00009" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce"> 9</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a>(</div>
|
||||
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> uint simd_group_id [[simdgroup_index_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> uint3 tid [[thread_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> uint3 tsize [[threads_per_grid]]) {</div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> </div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> <span class="comment">// Case 1: Small row small column</span></div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keywordflow">if</span> (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {</div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> U totals[31];</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < 31; i++) {</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> }</div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> </div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keywordtype">short</span> stride = reduction_stride;</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keywordtype">short</span> size = reduction_size;</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keywordtype">short</span> blocks = stride / N_READS;</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> <span class="keywordtype">short</span> extra = stride - blocks * N_READS;</div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> </div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> <span class="keywordtype">size_t</span> out_idx = tid.x + tsize.y * size_t(tid.y);</div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> </div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> <span class="keywordflow">for</span> (uint r = 0; r < non_col_reductions; r++) {</div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> </div>
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i < size; i++) {</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j < blocks; j++) {</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> k = 0; k < N_READS; k++) {</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> totals[j * N_READS + k] =</div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(totals[j * N_READS + k],</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> <span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i * stride + j * N_READS + k]));</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> }</div>
|
||||
<div class="line"><a id="l00003" name="l00003"></a><span class="lineno"> 3</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS></div>
|
||||
<div class="foldopen" id="foldopen00004" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00004" name="l00004"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5"> 4</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a>(</div>
|
||||
<div class="line"><a id="l00005" name="l00005"></a><span class="lineno"> 5</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"> 8</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"> 9</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint3 lid [[thread_position_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint3 lsize [[threads_per_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_reads = 4;</div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> </div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> U totals[n_reads];</div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> }</div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> </div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> <span class="keywordtype">size_t</span> column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;</div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> <span class="keywordflow">if</span> (column >= reduction_stride) {</div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keywordflow">return</span>;</div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> }</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keywordtype">bool</span> safe = column + n_reads <= reduction_stride;</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> </div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> in += in_idx + column;</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> </div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keywordtype">size_t</span> total_rows = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lid.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = lid.y; r < total_rows; r += lsize.y) {</div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i]), totals[i]);</div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> }</div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> U vals[n_reads];</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> vals[i] =</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> (column + i < reduction_stride) ? static_cast<U>(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> }</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
|
||||
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> }</div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> k = 0; k < extra; k++) {</div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> totals[blocks * N_READS + k] =</div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(totals[blocks * N_READS + k],</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> <span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i * stride + blocks * N_READS + k]));</div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> }</div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> }</div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> </div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> }</div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> out += out_idx * reduction_stride;</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j < stride; j++) {</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> out[j] = totals[j];</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> }</div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> }</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> </div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="comment">// Case 2: Long row small column</span></div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> <span class="keywordflow">else</span> <span class="keywordflow">if</span> (reduction_size * non_col_reductions < 32) {</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> U totals[N_READS];</div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> }</div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lsize.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> }</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> </div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> <span class="keywordflow">if</span> (lsize.y > 1) {</div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> <span class="comment">// lsize.y should be <= 8</span></div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> threadgroup U shared_vals[32 * 8 * n_reads];</div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];</div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> }</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> totals[i] = shared_vals[lid.x * n_reads + i];</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> }</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="keywordflow">for</span> (uint j = 1; j < lsize.y; j++) {</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> totals[i]);</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> }</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> </div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> <span class="keywordtype">short</span> size = reduction_size;</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <span class="keywordtype">size_t</span> offset = size_t(tid.x) * N_READS;</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="keywordtype">bool</span> safe = offset + N_READS <= reduction_stride;</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordtype">short</span> extra = reduction_stride - offset;</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> </div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keywordtype">size_t</span> out_idx = tid.y + tsize.z * size_t(tid.z);</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim) + offset;</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> </div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keywordflow">for</span> (uint r = 0; r < non_col_reductions; r++) {</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> </div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i < size; i++) {</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j < N_READS; j++) {</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> totals[j] =</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i * reduction_stride + j]), totals[j]);</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> }</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> }</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i < size; i++) {</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j < extra; j++) {</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> totals[j] =</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i * reduction_stride + j]), totals[j]);</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> }</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> }</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> }</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> </div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> }</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> out += out_idx * reduction_stride + offset;</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> }</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i < extra; i++) {</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> }</div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> }</div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> }</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> </div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="comment">// Case 3: Long row medium column</span></div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> threadgroup U shared_vals[1024];</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> U totals[N_READS];</div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> }</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> </div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keywordtype">short</span> stride = reduction_stride;</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> short2 tile((stride + N_READS - 1) / N_READS, 32);</div>
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> short2 offset((lid % tile.x) * N_READS, lid / tile.x);</div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <span class="keywordtype">short</span> sm_stride = tile.x * N_READS;</div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> <span class="keywordtype">bool</span> safe = offset.x + N_READS <= stride;</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> </div>
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim) + offset.x;</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> </div>
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="comment">// Read cooperatively and contiguously and aggregate the partial results.</span></div>
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r < total; r += <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>) {</div>
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> </div>
|
||||
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i]), totals[i]);</div>
|
||||
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"> 147</span> }</div>
|
||||
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"> 149</span> U vals[N_READS];</div>
|
||||
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
|
||||
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> }</div>
|
||||
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> }</div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> }</div>
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> </div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(<a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> }</div>
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> </div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="comment">// Each thread holds N_READS partial results but the simdgroups are not</span></div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="comment">// aligned to do the reduction across the simdgroup so we write our results</span></div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="comment">// in the shared memory and read them back according to the simdgroup.</span></div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];</div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> }</div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(</div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> }</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> </div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keywordtype">short</span> column = simd_group_id * N_READS;</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> out += out_idx * reduction_stride + column;</div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keywordflow">if</span> (column + N_READS <= stride) {</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_READS; i++) {</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> }</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i < stride; i++) {</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> }</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> }</div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> }</div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> }</div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span>}</div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> }</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> }</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> }</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> </div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> out += out_idx * reduction_stride + column;</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> }</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> }</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> }</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> }</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> </div>
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN></div>
|
||||
<div class="foldopen" id="foldopen00202" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"> 202</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a>(</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 4;</div>
|
||||
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
|
||||
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
|
||||
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
|
||||
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> </div>
|
||||
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> threadgroup U shared_vals[BN * BM];</div>
|
||||
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> U totals[n_reads];</div>
|
||||
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> </div>
|
||||
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> }</div>
|
||||
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> </div>
|
||||
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
|
||||
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
|
||||
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
|
||||
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> <span class="keywordtype">bool</span> safe = column + n_reads <= reduction_stride;</div>
|
||||
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> </div>
|
||||
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
|
||||
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> in += in_idx + column;</div>
|
||||
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"> 241</span> </div>
|
||||
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00244" name="l00244"></a><span class="lineno"> 244</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r < total; r += BM) {</div>
|
||||
<div class="line"><a id="l00245" name="l00245"></a><span class="lineno"> 245</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00246" name="l00246"></a><span class="lineno"> 246</span> </div>
|
||||
<div class="line"><a id="l00247" name="l00247"></a><span class="lineno"> 247</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00248" name="l00248"></a><span class="lineno"> 248</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00249" name="l00249"></a><span class="lineno"> 249</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i]), totals[i]);</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> </div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS></div>
|
||||
<div class="foldopen" id="foldopen00097" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb"> 97</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a>(</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& out_size [[buffer(11)]],</div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> uint3 lid [[thread_position_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> uint3 lsize [[threads_per_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> </div>
|
||||
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">size_t</span> out_idx = gid.x + gsize.x * size_t(gid.y);</div>
|
||||
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> in += in_idx + lid.x;</div>
|
||||
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> </div>
|
||||
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> U total = Op::init;</div>
|
||||
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> <span class="keywordtype">size_t</span> total_rows = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = gid.z * lsize.y + lid.y; r < total_rows;</div>
|
||||
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> r += lsize.y * gsize.z) {</div>
|
||||
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> total = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(*row), total);</div>
|
||||
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lsize.y * gsize.z, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> }</div>
|
||||
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> </div>
|
||||
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> threadgroup U shared_vals[32 * 32];</div>
|
||||
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> shared_vals[lid.y * lsize.x + lid.x] = total;</div>
|
||||
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
|
||||
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keywordflow">for</span> (uint i = 1; i < lsize.y; i++) {</div>
|
||||
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> total = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(total, shared_vals[i * lsize.x + lid.x]);</div>
|
||||
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> }</div>
|
||||
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;</div>
|
||||
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> }</div>
|
||||
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> </div>
|
||||
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN></div>
|
||||
<div class="foldopen" id="foldopen00155" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"> 155</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a>(</div>
|
||||
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
|
||||
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 8;</div>
|
||||
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
|
||||
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
|
||||
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
|
||||
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> </div>
|
||||
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> threadgroup U shared_vals[BN * BM];</div>
|
||||
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> U totals[n_reads];</div>
|
||||
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> </div>
|
||||
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> }</div>
|
||||
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> </div>
|
||||
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
|
||||
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
|
||||
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
|
||||
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keywordtype">bool</span> safe = column + n_reads <= reduction_stride;</div>
|
||||
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
|
||||
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
|
||||
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> in += in_idx + column;</div>
|
||||
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> </div>
|
||||
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r < total; r += BM) {</div>
|
||||
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> </div>
|
||||
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i]), totals[i]);</div>
|
||||
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> }</div>
|
||||
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> U vals[n_reads];</div>
|
||||
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> vals[i] =</div>
|
||||
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> (column + i < reduction_stride) ? static_cast<U>(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
|
||||
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> }</div>
|
||||
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
|
||||
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> }</div>
|
||||
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> }</div>
|
||||
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> </div>
|
||||
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(BM, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> }</div>
|
||||
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> </div>
|
||||
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> <span class="comment">// We can use a simd reduction to accumulate across BM so each thread writes</span></div>
|
||||
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="comment">// the partial output to SM and then each simdgroup does BN / n_simdgroups</span></div>
|
||||
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="comment">// accumulations.</span></div>
|
||||
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> <span class="keywordflow">if</span> (BM == 32) {</div>
|
||||
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
|
||||
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">static_assert</span>(</div>
|
||||
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> BM != 32 || n_outputs == n_reads,</div>
|
||||
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> <span class="stringliteral">"The tile should be selected such that n_outputs == n_reads"</span>);</div>
|
||||
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
|
||||
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> }</div>
|
||||
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
|
||||
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
|
||||
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> }</div>
|
||||
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> </div>
|
||||
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
|
||||
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
|
||||
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> out += out_idx * reduction_stride + out_column;</div>
|
||||
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> <span class="keywordflow">if</span> (out_column + n_outputs <= reduction_stride) {</div>
|
||||
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"> 241</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span> }</div>
|
||||
<div class="line"><a id="l00244" name="l00244"></a><span class="lineno"> 244</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00245" name="l00245"></a><span class="lineno"> 245</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00246" name="l00246"></a><span class="lineno"> 246</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00247" name="l00247"></a><span class="lineno"> 247</span> }</div>
|
||||
<div class="line"><a id="l00248" name="l00248"></a><span class="lineno"> 248</span> }</div>
|
||||
<div class="line"><a id="l00249" name="l00249"></a><span class="lineno"> 249</span> }</div>
|
||||
<div class="line"><a id="l00250" name="l00250"></a><span class="lineno"> 250</span> }</div>
|
||||
<div class="line"><a id="l00251" name="l00251"></a><span class="lineno"> 251</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00252" name="l00252"></a><span class="lineno"> 252</span> U vals[n_reads];</div>
|
||||
<div class="line"><a id="l00253" name="l00253"></a><span class="lineno"> 253</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00254" name="l00254"></a><span class="lineno"> 254</span> vals[i] =</div>
|
||||
<div class="line"><a id="l00255" name="l00255"></a><span class="lineno"> 255</span> (column + i < reduction_stride) ? static_cast<U>(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
|
||||
<div class="line"><a id="l00256" name="l00256"></a><span class="lineno"> 256</span> }</div>
|
||||
<div class="line"><a id="l00251" name="l00251"></a><span class="lineno"> 251</span> </div>
|
||||
<div class="line"><a id="l00252" name="l00252"></a><span class="lineno"> 252</span> <span class="comment">// Each thread holds n_reads partial results. We write them all out to shared</span></div>
|
||||
<div class="line"><a id="l00253" name="l00253"></a><span class="lineno"> 253</span> <span class="comment">// memory and threads with offset.y == 0 aggregate the columns and write the</span></div>
|
||||
<div class="line"><a id="l00254" name="l00254"></a><span class="lineno"> 254</span> <span class="comment">// outputs.</span></div>
|
||||
<div class="line"><a id="l00255" name="l00255"></a><span class="lineno"> 255</span> <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00256" name="l00256"></a><span class="lineno"> 256</span> <span class="keywordtype">short</span> x_block = offset.x / n_reads;</div>
|
||||
<div class="line"><a id="l00257" name="l00257"></a><span class="lineno"> 257</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00258" name="l00258"></a><span class="lineno"> 258</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
|
||||
<div class="line"><a id="l00258" name="l00258"></a><span class="lineno"> 258</span> shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];</div>
|
||||
<div class="line"><a id="l00259" name="l00259"></a><span class="lineno"> 259</span> }</div>
|
||||
<div class="line"><a id="l00260" name="l00260"></a><span class="lineno"> 260</span> }</div>
|
||||
<div class="line"><a id="l00261" name="l00261"></a><span class="lineno"> 261</span> </div>
|
||||
<div class="line"><a id="l00262" name="l00262"></a><span class="lineno"> 262</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(BM, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00263" name="l00263"></a><span class="lineno"> 263</span> }</div>
|
||||
<div class="line"><a id="l00264" name="l00264"></a><span class="lineno"> 264</span> </div>
|
||||
<div class="line"><a id="l00265" name="l00265"></a><span class="lineno"> 265</span> <span class="comment">// We can use a simd reduction to accumulate across BM so each thread writes</span></div>
|
||||
<div class="line"><a id="l00266" name="l00266"></a><span class="lineno"> 266</span> <span class="comment">// the partial output to SM and then each simdgroup does BN / n_simdgroups</span></div>
|
||||
<div class="line"><a id="l00267" name="l00267"></a><span class="lineno"> 267</span> <span class="comment">// accumulations.</span></div>
|
||||
<div class="line"><a id="l00268" name="l00268"></a><span class="lineno"> 268</span> <span class="keywordflow">if</span> (BM == 32) {</div>
|
||||
<div class="line"><a id="l00269" name="l00269"></a><span class="lineno"> 269</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
|
||||
<div class="line"><a id="l00270" name="l00270"></a><span class="lineno"> 270</span> <span class="keyword">static_assert</span>(</div>
|
||||
<div class="line"><a id="l00271" name="l00271"></a><span class="lineno"> 271</span> BM != 32 || n_outputs == n_reads,</div>
|
||||
<div class="line"><a id="l00272" name="l00272"></a><span class="lineno"> 272</span> <span class="stringliteral">"The tile should be selected such that n_outputs == n_reads"</span>);</div>
|
||||
<div class="line"><a id="l00273" name="l00273"></a><span class="lineno"> 273</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00274" name="l00274"></a><span class="lineno"> 274</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
|
||||
<div class="line"><a id="l00275" name="l00275"></a><span class="lineno"> 275</span> }</div>
|
||||
<div class="line"><a id="l00276" name="l00276"></a><span class="lineno"> 276</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00277" name="l00277"></a><span class="lineno"> 277</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
|
||||
<div class="line"><a id="l00278" name="l00278"></a><span class="lineno"> 278</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00279" name="l00279"></a><span class="lineno"> 279</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00280" name="l00280"></a><span class="lineno"> 280</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
|
||||
<div class="line"><a id="l00260" name="l00260"></a><span class="lineno"> 260</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00261" name="l00261"></a><span class="lineno"> 261</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
|
||||
<div class="line"><a id="l00262" name="l00262"></a><span class="lineno"> 262</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00263" name="l00263"></a><span class="lineno"> 263</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 1; j < BM; j++) {</div>
|
||||
<div class="line"><a id="l00264" name="l00264"></a><span class="lineno"> 264</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00265" name="l00265"></a><span class="lineno"> 265</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);</div>
|
||||
<div class="line"><a id="l00266" name="l00266"></a><span class="lineno"> 266</span> }</div>
|
||||
<div class="line"><a id="l00267" name="l00267"></a><span class="lineno"> 267</span> }</div>
|
||||
<div class="line"><a id="l00268" name="l00268"></a><span class="lineno"> 268</span> }</div>
|
||||
<div class="line"><a id="l00269" name="l00269"></a><span class="lineno"> 269</span> </div>
|
||||
<div class="line"><a id="l00270" name="l00270"></a><span class="lineno"> 270</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00271" name="l00271"></a><span class="lineno"> 271</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
|
||||
<div class="line"><a id="l00272" name="l00272"></a><span class="lineno"> 272</span> out += out_idx * reduction_stride + column;</div>
|
||||
<div class="line"><a id="l00273" name="l00273"></a><span class="lineno"> 273</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00274" name="l00274"></a><span class="lineno"> 274</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00275" name="l00275"></a><span class="lineno"> 275</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00276" name="l00276"></a><span class="lineno"> 276</span> }</div>
|
||||
<div class="line"><a id="l00277" name="l00277"></a><span class="lineno"> 277</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00278" name="l00278"></a><span class="lineno"> 278</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00279" name="l00279"></a><span class="lineno"> 279</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00280" name="l00280"></a><span class="lineno"> 280</span> }</div>
|
||||
<div class="line"><a id="l00281" name="l00281"></a><span class="lineno"> 281</span> }</div>
|
||||
<div class="line"><a id="l00282" name="l00282"></a><span class="lineno"> 282</span> </div>
|
||||
<div class="line"><a id="l00283" name="l00283"></a><span class="lineno"> 283</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00284" name="l00284"></a><span class="lineno"> 284</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
|
||||
<div class="line"><a id="l00285" name="l00285"></a><span class="lineno"> 285</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
|
||||
<div class="line"><a id="l00286" name="l00286"></a><span class="lineno"> 286</span> out += out_idx * reduction_stride + out_column;</div>
|
||||
<div class="line"><a id="l00287" name="l00287"></a><span class="lineno"> 287</span> <span class="keywordflow">if</span> (out_column + n_outputs <= reduction_stride) {</div>
|
||||
<div class="line"><a id="l00288" name="l00288"></a><span class="lineno"> 288</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00289" name="l00289"></a><span class="lineno"> 289</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00290" name="l00290"></a><span class="lineno"> 290</span> }</div>
|
||||
<div class="line"><a id="l00291" name="l00291"></a><span class="lineno"> 291</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00292" name="l00292"></a><span class="lineno"> 292</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00293" name="l00293"></a><span class="lineno"> 293</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00294" name="l00294"></a><span class="lineno"> 294</span> }</div>
|
||||
<div class="line"><a id="l00295" name="l00295"></a><span class="lineno"> 295</span> }</div>
|
||||
<div class="line"><a id="l00296" name="l00296"></a><span class="lineno"> 296</span> }</div>
|
||||
<div class="line"><a id="l00297" name="l00297"></a><span class="lineno"> 297</span> }</div>
|
||||
<div class="line"><a id="l00298" name="l00298"></a><span class="lineno"> 298</span> </div>
|
||||
<div class="line"><a id="l00299" name="l00299"></a><span class="lineno"> 299</span> <span class="comment">// Each thread holds n_reads partial results. We write them all out to shared</span></div>
|
||||
<div class="line"><a id="l00300" name="l00300"></a><span class="lineno"> 300</span> <span class="comment">// memory and threads with offset.y == 0 aggregate the columns and write the</span></div>
|
||||
<div class="line"><a id="l00301" name="l00301"></a><span class="lineno"> 301</span> <span class="comment">// outputs.</span></div>
|
||||
<div class="line"><a id="l00302" name="l00302"></a><span class="lineno"> 302</span> <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00303" name="l00303"></a><span class="lineno"> 303</span> <span class="keywordtype">short</span> x_block = offset.x / n_reads;</div>
|
||||
<div class="line"><a id="l00304" name="l00304"></a><span class="lineno"> 304</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00305" name="l00305"></a><span class="lineno"> 305</span> shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];</div>
|
||||
<div class="line"><a id="l00306" name="l00306"></a><span class="lineno"> 306</span> }</div>
|
||||
<div class="line"><a id="l00307" name="l00307"></a><span class="lineno"> 307</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00308" name="l00308"></a><span class="lineno"> 308</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
|
||||
<div class="line"><a id="l00309" name="l00309"></a><span class="lineno"> 309</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00310" name="l00310"></a><span class="lineno"> 310</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 1; j < BM; j++) {</div>
|
||||
<div class="line"><a id="l00311" name="l00311"></a><span class="lineno"> 311</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00312" name="l00312"></a><span class="lineno"> 312</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);</div>
|
||||
<div class="line"><a id="l00313" name="l00313"></a><span class="lineno"> 313</span> }</div>
|
||||
<div class="line"><a id="l00314" name="l00314"></a><span class="lineno"> 314</span> }</div>
|
||||
<div class="line"><a id="l00315" name="l00315"></a><span class="lineno"> 315</span> }</div>
|
||||
<div class="line"><a id="l00316" name="l00316"></a><span class="lineno"> 316</span> </div>
|
||||
<div class="line"><a id="l00317" name="l00317"></a><span class="lineno"> 317</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00318" name="l00318"></a><span class="lineno"> 318</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
|
||||
<div class="line"><a id="l00319" name="l00319"></a><span class="lineno"> 319</span> out += out_idx * reduction_stride + column;</div>
|
||||
<div class="line"><a id="l00320" name="l00320"></a><span class="lineno"> 320</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00321" name="l00321"></a><span class="lineno"> 321</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00322" name="l00322"></a><span class="lineno"> 322</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00323" name="l00323"></a><span class="lineno"> 323</span> }</div>
|
||||
<div class="line"><a id="l00324" name="l00324"></a><span class="lineno"> 324</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00325" name="l00325"></a><span class="lineno"> 325</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00326" name="l00326"></a><span class="lineno"> 326</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00327" name="l00327"></a><span class="lineno"> 327</span> }</div>
|
||||
<div class="line"><a id="l00328" name="l00328"></a><span class="lineno"> 328</span> }</div>
|
||||
<div class="line"><a id="l00329" name="l00329"></a><span class="lineno"> 329</span> }</div>
|
||||
<div class="line"><a id="l00330" name="l00330"></a><span class="lineno"> 330</span> }</div>
|
||||
<div class="line"><a id="l00331" name="l00331"></a><span class="lineno"> 331</span>}</div>
|
||||
<div class="line"><a id="l00282" name="l00282"></a><span class="lineno"> 282</span> }</div>
|
||||
<div class="line"><a id="l00283" name="l00283"></a><span class="lineno"> 283</span> }</div>
|
||||
<div class="line"><a id="l00284" name="l00284"></a><span class="lineno"> 284</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00285" name="l00285"></a><span class="lineno"> 285</span> </div>
|
||||
<div class="line"><a id="l00286" name="l00286"></a><span class="lineno"> 286</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN></div>
|
||||
<div class="foldopen" id="foldopen00287" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00287" name="l00287"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"> 287</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a>(</div>
|
||||
<div class="line"><a id="l00288" name="l00288"></a><span class="lineno"> 288</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00289" name="l00289"></a><span class="lineno"> 289</span> device U* out [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00290" name="l00290"></a><span class="lineno"> 290</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_size [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00291" name="l00291"></a><span class="lineno"> 291</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& reduction_stride [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00292" name="l00292"></a><span class="lineno"> 292</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00293" name="l00293"></a><span class="lineno"> 293</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00294" name="l00294"></a><span class="lineno"> 294</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& ndim [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00295" name="l00295"></a><span class="lineno"> 295</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00296" name="l00296"></a><span class="lineno"> 296</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
|
||||
<div class="line"><a id="l00297" name="l00297"></a><span class="lineno"> 297</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& reduce_ndim [[buffer(9)]],</div>
|
||||
<div class="line"><a id="l00298" name="l00298"></a><span class="lineno"> 298</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& non_col_reductions [[buffer(10)]],</div>
|
||||
<div class="line"><a id="l00299" name="l00299"></a><span class="lineno"> 299</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& out_size [[buffer(11)]],</div>
|
||||
<div class="line"><a id="l00300" name="l00300"></a><span class="lineno"> 300</span> uint3 gid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00301" name="l00301"></a><span class="lineno"> 301</span> uint3 gsize [[threadgroups_per_grid]],</div>
|
||||
<div class="line"><a id="l00302" name="l00302"></a><span class="lineno"> 302</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
|
||||
<div class="line"><a id="l00303" name="l00303"></a><span class="lineno"> 303</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00304" name="l00304"></a><span class="lineno"> 304</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
|
||||
<div class="line"><a id="l00305" name="l00305"></a><span class="lineno"> 305</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 8;</div>
|
||||
<div class="line"><a id="l00306" name="l00306"></a><span class="lineno"> 306</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
|
||||
<div class="line"><a id="l00307" name="l00307"></a><span class="lineno"> 307</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
|
||||
<div class="line"><a id="l00308" name="l00308"></a><span class="lineno"> 308</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
|
||||
<div class="line"><a id="l00309" name="l00309"></a><span class="lineno"> 309</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
|
||||
<div class="line"><a id="l00310" name="l00310"></a><span class="lineno"> 310</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> outer_blocks = 32;</div>
|
||||
<div class="line"><a id="l00311" name="l00311"></a><span class="lineno"> 311</span> <span class="keyword">static_assert</span>(BM == 32, <span class="stringliteral">"BM should be equal to 32"</span>);</div>
|
||||
<div class="line"><a id="l00312" name="l00312"></a><span class="lineno"> 312</span> </div>
|
||||
<div class="line"><a id="l00313" name="l00313"></a><span class="lineno"> 313</span> threadgroup U shared_vals[BN * BM];</div>
|
||||
<div class="line"><a id="l00314" name="l00314"></a><span class="lineno"> 314</span> U totals[n_reads];</div>
|
||||
<div class="line"><a id="l00315" name="l00315"></a><span class="lineno"> 315</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc<NDIMS></a> loop;</div>
|
||||
<div class="line"><a id="l00316" name="l00316"></a><span class="lineno"> 316</span> <span class="keyword">const</span> device T* row;</div>
|
||||
<div class="line"><a id="l00317" name="l00317"></a><span class="lineno"> 317</span> </div>
|
||||
<div class="line"><a id="l00318" name="l00318"></a><span class="lineno"> 318</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00319" name="l00319"></a><span class="lineno"> 319</span> totals[i] = Op::init;</div>
|
||||
<div class="line"><a id="l00320" name="l00320"></a><span class="lineno"> 320</span> }</div>
|
||||
<div class="line"><a id="l00321" name="l00321"></a><span class="lineno"> 321</span> </div>
|
||||
<div class="line"><a id="l00322" name="l00322"></a><span class="lineno"> 322</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
|
||||
<div class="line"><a id="l00323" name="l00323"></a><span class="lineno"> 323</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
|
||||
<div class="line"><a id="l00324" name="l00324"></a><span class="lineno"> 324</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
|
||||
<div class="line"><a id="l00325" name="l00325"></a><span class="lineno"> 325</span> <span class="keywordtype">bool</span> safe = column + n_reads <= reduction_stride;</div>
|
||||
<div class="line"><a id="l00326" name="l00326"></a><span class="lineno"> 326</span> </div>
|
||||
<div class="line"><a id="l00327" name="l00327"></a><span class="lineno"> 327</span> <span class="keywordtype">size_t</span> full_idx = gid.y + gsize.y * size_t(gid.z);</div>
|
||||
<div class="line"><a id="l00328" name="l00328"></a><span class="lineno"> 328</span> <span class="keywordtype">size_t</span> block_idx = full_idx / out_size;</div>
|
||||
<div class="line"><a id="l00329" name="l00329"></a><span class="lineno"> 329</span> <span class="keywordtype">size_t</span> out_idx = full_idx % out_size;</div>
|
||||
<div class="line"><a id="l00330" name="l00330"></a><span class="lineno"> 330</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
|
||||
<div class="line"><a id="l00331" name="l00331"></a><span class="lineno"> 331</span> in += in_idx + column;</div>
|
||||
<div class="line"><a id="l00332" name="l00332"></a><span class="lineno"> 332</span> </div>
|
||||
<div class="line"><a id="l00333" name="l00333"></a><span class="lineno"> 333</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
|
||||
<div class="line"><a id="l00334" name="l00334"></a><span class="lineno"> 334</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y + block_idx * BM, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00335" name="l00335"></a><span class="lineno"> 335</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y + block_idx * BM; r < total;</div>
|
||||
<div class="line"><a id="l00336" name="l00336"></a><span class="lineno"> 336</span> r += outer_blocks * BM) {</div>
|
||||
<div class="line"><a id="l00337" name="l00337"></a><span class="lineno"> 337</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
|
||||
<div class="line"><a id="l00338" name="l00338"></a><span class="lineno"> 338</span> </div>
|
||||
<div class="line"><a id="l00339" name="l00339"></a><span class="lineno"> 339</span> <span class="keywordflow">if</span> (safe) {</div>
|
||||
<div class="line"><a id="l00340" name="l00340"></a><span class="lineno"> 340</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00341" name="l00341"></a><span class="lineno"> 341</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast<</span>U<span class="keyword">></span>(row[i]), totals[i]);</div>
|
||||
<div class="line"><a id="l00342" name="l00342"></a><span class="lineno"> 342</span> }</div>
|
||||
<div class="line"><a id="l00343" name="l00343"></a><span class="lineno"> 343</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00344" name="l00344"></a><span class="lineno"> 344</span> U vals[n_reads];</div>
|
||||
<div class="line"><a id="l00345" name="l00345"></a><span class="lineno"> 345</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00346" name="l00346"></a><span class="lineno"> 346</span> vals[i] =</div>
|
||||
<div class="line"><a id="l00347" name="l00347"></a><span class="lineno"> 347</span> (column + i < reduction_stride) ? static_cast<U>(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
|
||||
<div class="line"><a id="l00348" name="l00348"></a><span class="lineno"> 348</span> }</div>
|
||||
<div class="line"><a id="l00349" name="l00349"></a><span class="lineno"> 349</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00350" name="l00350"></a><span class="lineno"> 350</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
|
||||
<div class="line"><a id="l00351" name="l00351"></a><span class="lineno"> 351</span> }</div>
|
||||
<div class="line"><a id="l00352" name="l00352"></a><span class="lineno"> 352</span> }</div>
|
||||
<div class="line"><a id="l00353" name="l00353"></a><span class="lineno"> 353</span> </div>
|
||||
<div class="line"><a id="l00354" name="l00354"></a><span class="lineno"> 354</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(outer_blocks * BM, reduce_shape, reduce_strides);</div>
|
||||
<div class="line"><a id="l00355" name="l00355"></a><span class="lineno"> 355</span> }</div>
|
||||
<div class="line"><a id="l00356" name="l00356"></a><span class="lineno"> 356</span> </div>
|
||||
<div class="line"><a id="l00357" name="l00357"></a><span class="lineno"> 357</span> <span class="comment">// We can use a simd reduction to accumulate across BM so each thread writes</span></div>
|
||||
<div class="line"><a id="l00358" name="l00358"></a><span class="lineno"> 358</span> <span class="comment">// the partial output to SM and then each simdgroup does BN / n_simdgroups</span></div>
|
||||
<div class="line"><a id="l00359" name="l00359"></a><span class="lineno"> 359</span> <span class="comment">// accumulations.</span></div>
|
||||
<div class="line"><a id="l00360" name="l00360"></a><span class="lineno"> 360</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_reads; i++) {</div>
|
||||
<div class="line"><a id="l00361" name="l00361"></a><span class="lineno"> 361</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
|
||||
<div class="line"><a id="l00362" name="l00362"></a><span class="lineno"> 362</span> }</div>
|
||||
<div class="line"><a id="l00363" name="l00363"></a><span class="lineno"> 363</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00364" name="l00364"></a><span class="lineno"> 364</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
|
||||
<div class="line"><a id="l00365" name="l00365"></a><span class="lineno"> 365</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00366" name="l00366"></a><span class="lineno"> 366</span> totals[i] =</div>
|
||||
<div class="line"><a id="l00367" name="l00367"></a><span class="lineno"> 367</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
|
||||
<div class="line"><a id="l00368" name="l00368"></a><span class="lineno"> 368</span> }</div>
|
||||
<div class="line"><a id="l00369" name="l00369"></a><span class="lineno"> 369</span> </div>
|
||||
<div class="line"><a id="l00370" name="l00370"></a><span class="lineno"> 370</span> <span class="comment">// Write the output.</span></div>
|
||||
<div class="line"><a id="l00371" name="l00371"></a><span class="lineno"> 371</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
|
||||
<div class="line"><a id="l00372" name="l00372"></a><span class="lineno"> 372</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
|
||||
<div class="line"><a id="l00373" name="l00373"></a><span class="lineno"> 373</span> out += full_idx * reduction_stride + out_column;</div>
|
||||
<div class="line"><a id="l00374" name="l00374"></a><span class="lineno"> 374</span> <span class="keywordflow">if</span> (out_column + n_outputs <= reduction_stride) {</div>
|
||||
<div class="line"><a id="l00375" name="l00375"></a><span class="lineno"> 375</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < n_outputs; i++) {</div>
|
||||
<div class="line"><a id="l00376" name="l00376"></a><span class="lineno"> 376</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00377" name="l00377"></a><span class="lineno"> 377</span> }</div>
|
||||
<div class="line"><a id="l00378" name="l00378"></a><span class="lineno"> 378</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00379" name="l00379"></a><span class="lineno"> 379</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i < reduction_stride; i++) {</div>
|
||||
<div class="line"><a id="l00380" name="l00380"></a><span class="lineno"> 380</span> out[i] = totals[i];</div>
|
||||
<div class="line"><a id="l00381" name="l00381"></a><span class="lineno"> 381</span> }</div>
|
||||
<div class="line"><a id="l00382" name="l00382"></a><span class="lineno"> 382</span> }</div>
|
||||
<div class="line"><a id="l00383" name="l00383"></a><span class="lineno"> 383</span> }</div>
|
||||
<div class="line"><a id="l00384" name="l00384"></a><span class="lineno"> 384</span>}</div>
|
||||
</div>
|
||||
<div class="ttc" id="abackend_2metal_2kernels_2reduction_2ops_8h_html_a515b75d563a93d3c09ee677948dc83e3"><div class="ttname"><a href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a></div><div class="ttdeci">static constant constexpr const uint8_t simd_size</div><div class="ttdef"><b>Definition</b> ops.h:22</div></div>
|
||||
<div class="ttc" id="abackend_2metal_2kernels_2utils_8h_html_a8fd0c8fc6058e650fc99bca8b6acd7d1"><div class="ttname"><a href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a></div><div class="ttdeci">METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:87</div></div>
|
||||
<div class="ttc" id="acommon_2binary_8h_html_a70228731d29946574b238d21fb4b360c"><div class="ttname"><a href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a></div><div class="ttdeci">Op op</div><div class="ttdef"><b>Definition</b> binary.h:129</div></div>
|
||||
<div class="ttc" id="adefines_8h_html_a2ad505864a2ab786147766900bc18c21"><div class="ttname"><a href="defines_8h.html#a2ad505864a2ab786147766900bc18c21">REDUCE_N_READS</a></div><div class="ttdeci">static constexpr int REDUCE_N_READS</div><div class="ttdef"><b>Definition</b> defines.h:12</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_a11bfc6112ae2386ac03f5ea7b7d93385"><div class="ttname"><a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a></div><div class="ttdeci">void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdoc">Our approach is the following simple looped approach:</div><div class="ttdef"><b>Definition</b> reduce_col.h:202</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_adf7aeb18cd1d5042cf6d9b46b582d8ce"><div class="ttname"><a href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a></div><div class="ttdeci">void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:9</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><div class="ttname"><a href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a></div><div class="ttdeci">void col_reduce_2pass(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdef"><b>Definition</b> reduce_col.h:287</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_a11bfc6112ae2386ac03f5ea7b7d93385"><div class="ttname"><a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a></div><div class="ttdeci">void col_reduce_looped(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdoc">Our approach is the following simple looped approach:</div><div class="ttdef"><b>Definition</b> reduce_col.h:155</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_a5b4f4c4c247ad341ff8d31dcbbbce0eb"><div class="ttname"><a href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a></div><div class="ttdeci">void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, const constant size_t &out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:97</div></div>
|
||||
<div class="ttc" id="areduce__col_8h_html_a7c378443a2b6f4d9210db8a21a9ac4f5"><div class="ttname"><a href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a></div><div class="ttdeci">void col_reduce_small(const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:4</div></div>
|
||||
<div class="ttc" id="astructlooped__elem__to__loc_html"><div class="ttname"><a href="structlooped__elem__to__loc.html">looped_elem_to_loc</a></div><div class="ttdef"><b>Definition</b> utils.h:197</div></div>
|
||||
<div class="ttc" id="astructlooped__elem__to__loc_html_a05558dabba889ee0d80ed4b567d901ca"><div class="ttname"><a href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">looped_elem_to_loc::next</a></div><div class="ttdeci">void next(const constant int *shape, const constant size_t *strides)</div><div class="ttdef"><b>Definition</b> utils.h:202</div></div>
|
||||
<div class="ttc" id="astructlooped__elem__to__loc_html_accc6d4957a8aeb38f5062754793b74d2"><div class="ttname"><a href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">looped_elem_to_loc::location</a></div><div class="ttdeci">offset_t location(offset_t, const constant int *, const constant size_t *, int)</div><div class="ttdef"><b>Definition</b> utils.h:229</div></div>
|
||||
|
15
docs/build/html/sdpa__vector_8h.html
vendored
15
docs/build/html/sdpa__vector_8h.html
vendored
@ -99,13 +99,13 @@ $(function(){ initResizable(false); });
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a6f0d7918430064bab910bdaa6c64e927" id="r_a6f0d7918430064bab910bdaa6c64e927"><td class="memTemplParams" colspan="2">template<typename T , int D> </td></tr>
|
||||
<tr class="memitem:a6f0d7918430064bab910bdaa6c64e927"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a6f0d7918430064bab910bdaa6c64e927"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4bf36f16e16c1c62d9b243573568e5ae" id="r_a4bf36f16e16c1c62d9b243573568e5ae"><td class="memTemplParams" colspan="2">template<typename T , int D> </td></tr>
|
||||
<tr class="memitem:a4bf36f16e16c1c62d9b243573568e5ae"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a4bf36f16e16c1c62d9b243573568e5ae"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="a6f0d7918430064bab910bdaa6c64e927" name="a6f0d7918430064bab910bdaa6c64e927"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a6f0d7918430064bab910bdaa6c64e927">◆ </a></span>sdpa_vector()</h2>
|
||||
<a id="a4bf36f16e16c1c62d9b243573568e5ae" name="a4bf36f16e16c1c62d9b243573568e5ae"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4bf36f16e16c1c62d9b243573568e5ae">◆ </a></span>sdpa_vector()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@ -147,6 +147,11 @@ template<typename T , int D> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
|
203
docs/build/html/sdpa__vector_8h_source.html
vendored
203
docs/build/html/sdpa__vector_8h_source.html
vendored
@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> </div>
|
||||
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span><span class="keyword">template</span> <<span class="keyword">typename</span> T, <span class="keywordtype">int</span> D></div>
|
||||
<div class="foldopen" id="foldopen00008" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"><a class="line" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927"> 8</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a>(</div>
|
||||
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"><a class="line" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae"> 8</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a>(</div>
|
||||
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"> 9</span> <span class="keyword">const</span> device T* queries [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> device T* keys [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> <span class="keyword">const</span> device T* values [[buffer(2)]],</div>
|
||||
@ -107,113 +107,114 @@ $(function(){ initResizable(false); });
|
||||
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& gqa_factor,</div>
|
||||
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& N,</div>
|
||||
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& k_stride,</div>
|
||||
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">float</span>& scale,</div>
|
||||
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> uint3 tid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint simd_lid [[thread_index_in_simdgroup]]) {</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BN = 32;</div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BD = 32;</div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> elem_per_thread = D / BD;</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> </div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <span class="keyword">const</span> <span class="keywordtype">int</span> stride = BN * D;</div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> </div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">typedef</span> <span class="keywordtype">float</span> U;</div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> </div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> thread U q[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> thread U k[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> thread U o[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> </div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> threadgroup U outputs[BN * BD];</div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> threadgroup U max_scores[BN];</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> threadgroup U sum_exp_scores[BN];</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> </div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="comment">// Adjust positions</span></div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keyword">const</span> <span class="keywordtype">int</span> head_idx = tid.y;</div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keyword">const</span> <span class="keywordtype">int</span> kv_head_idx = head_idx / gqa_factor;</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> queries += head_idx * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> out += head_idx * D + simd_gid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> </div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> <span class="comment">// Read the query and 0 the output accumulator</span></div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> q[i] = <span class="keyword">static_cast<</span>U<span class="keyword">></span>(scale) * queries[i];</div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> }</div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> o[i] = 0;</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> }</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> </div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> U max_score = -INFINITY;</div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> U sum_exp_score = 0;</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> </div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> <span class="comment">// For each key</span></div>
|
||||
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = simd_gid; i < N; i += BN) {</div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="comment">// Read the key</span></div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> k[i] = keys[i];</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> }</div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> </div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> <span class="comment">// Compute the i-th score</span></div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> U score = 0;</div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> score += q[i] * k[i];</div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> }</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(score);</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> </div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> <span class="comment">// Update the accumulators</span></div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">max</a>(max_score, score);</div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> U exp_score = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(score - new_max);</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> </div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> max_score = new_max;</div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> sum_exp_score = sum_exp_score * factor + exp_score;</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> </div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="comment">// Update the output accumulator</span></div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> o[i] = o[i] * factor + exp_score * values[i];</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> }</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> </div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="comment">// Move the pointers to the next kv</span></div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> keys += stride;</div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> values += stride;</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> }</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> </div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="comment">// Each thread has a partial part of the output so we need to combine them.</span></div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> </div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="comment">// First let's communicate the max and sum_exp</span></div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> max_scores[simd_gid] = max_score;</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> sum_exp_scores[simd_gid] = sum_exp_score;</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> }</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> max_score = max_scores[simd_lid];</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">simd_max</a>(max_score);</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> sum_exp_score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(sum_exp_scores[simd_lid] * factor);</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> </div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="comment">// Now we need to aggregate all the outputs</span></div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> outputs[simd_lid * BD + simd_gid] = o[i];</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> o[i] = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> }</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> </div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="comment">// And write the output</span></div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> out[i] = <span class="keyword">static_cast<</span>T<span class="keyword">></span>(o[i]);</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> }</div>
|
||||
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>& v_stride,</div>
|
||||
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <span class="keyword">const</span> constant <span class="keywordtype">float</span>& scale,</div>
|
||||
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint3 tid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> uint simd_lid [[thread_index_in_simdgroup]]) {</div>
|
||||
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BN = 32;</div>
|
||||
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BD = 32;</div>
|
||||
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> elem_per_thread = D / BD;</div>
|
||||
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> </div>
|
||||
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> <span class="keywordtype">int</span> stride = BN * D;</div>
|
||||
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> </div>
|
||||
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <span class="keyword">typedef</span> <span class="keywordtype">float</span> U;</div>
|
||||
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> </div>
|
||||
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> thread U q[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> thread U k[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> thread U o[elem_per_thread];</div>
|
||||
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> </div>
|
||||
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> threadgroup U outputs[BN * BD];</div>
|
||||
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> threadgroup U max_scores[BN];</div>
|
||||
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> threadgroup U sum_exp_scores[BN];</div>
|
||||
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> </div>
|
||||
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="comment">// Adjust positions</span></div>
|
||||
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keyword">const</span> <span class="keywordtype">int</span> head_idx = tid.y;</div>
|
||||
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keyword">const</span> <span class="keywordtype">int</span> kv_head_idx = head_idx / gqa_factor;</div>
|
||||
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> queries += head_idx * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> out += head_idx * D + simd_gid * elem_per_thread;</div>
|
||||
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> </div>
|
||||
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="comment">// Read the query and 0 the output accumulator</span></div>
|
||||
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> q[i] = <span class="keyword">static_cast<</span>U<span class="keyword">></span>(scale) * queries[i];</div>
|
||||
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> }</div>
|
||||
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> o[i] = 0;</div>
|
||||
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> }</div>
|
||||
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> </div>
|
||||
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> U max_score = -INFINITY;</div>
|
||||
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> U sum_exp_score = 0;</div>
|
||||
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> </div>
|
||||
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> <span class="comment">// For each key</span></div>
|
||||
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = simd_gid; i < N; i += BN) {</div>
|
||||
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="comment">// Read the key</span></div>
|
||||
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> k[i] = keys[i];</div>
|
||||
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> }</div>
|
||||
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> </div>
|
||||
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> <span class="comment">// Compute the i-th score</span></div>
|
||||
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> U score = 0;</div>
|
||||
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> score += q[i] * k[i];</div>
|
||||
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> }</div>
|
||||
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(score);</div>
|
||||
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> </div>
|
||||
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> <span class="comment">// Update the accumulators</span></div>
|
||||
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">max</a>(max_score, score);</div>
|
||||
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
|
||||
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> U exp_score = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(score - new_max);</div>
|
||||
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> </div>
|
||||
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> max_score = new_max;</div>
|
||||
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> sum_exp_score = sum_exp_score * factor + exp_score;</div>
|
||||
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> </div>
|
||||
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="comment">// Update the output accumulator</span></div>
|
||||
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> o[i] = o[i] * factor + exp_score * values[i];</div>
|
||||
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> }</div>
|
||||
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> </div>
|
||||
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="comment">// Move the pointers to the next kv</span></div>
|
||||
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> keys += stride;</div>
|
||||
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> values += stride;</div>
|
||||
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> }</div>
|
||||
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> </div>
|
||||
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="comment">// Each thread has a partial part of the output so we need to combine them.</span></div>
|
||||
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> </div>
|
||||
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="comment">// First let's communicate the max and sum_exp</span></div>
|
||||
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
|
||||
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> max_scores[simd_gid] = max_score;</div>
|
||||
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> sum_exp_scores[simd_gid] = sum_exp_score;</div>
|
||||
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> }</div>
|
||||
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> max_score = max_scores[simd_lid];</div>
|
||||
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">simd_max</a>(max_score);</div>
|
||||
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
|
||||
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> sum_exp_score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(sum_exp_scores[simd_lid] * factor);</div>
|
||||
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> </div>
|
||||
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="comment">// Now we need to aggregate all the outputs</span></div>
|
||||
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> outputs[simd_lid * BD + simd_gid] = o[i];</div>
|
||||
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> o[i] = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;</div>
|
||||
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> }</div>
|
||||
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> </div>
|
||||
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="comment">// And write the output</span></div>
|
||||
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
|
||||
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < elem_per_thread; i++) {</div>
|
||||
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> out[i] = <span class="keyword">static_cast<</span>T<span class="keyword">></span>(o[i]);</div>
|
||||
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> }</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span>}</div>
|
||||
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> }</div>
|
||||
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span>}</div>
|
||||
</div>
|
||||
<div class="ttc" id="anamespacemetal_1_1fast_html_ad3dbd387b63373c29e3449609f763ede"><div class="ttname"><a href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">metal::fast::exp</a></div><div class="ttdeci">METAL_FUNC bfloat16_t exp(bfloat16_t x)</div><div class="ttdef"><b>Definition</b> bf16_math.h:242</div></div>
|
||||
<div class="ttc" id="anamespacemetal_html"><div class="ttname"><a href="namespacemetal.html">metal</a></div><div class="ttdef"><b>Definition</b> bf16.h:265</div></div>
|
||||
<div class="ttc" id="anamespacemetal_html_a048cad0aca52cb737ebf103e76bd1c49"><div class="ttname"><a href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">metal::simd_max</a></div><div class="ttdeci">METAL_FUNC bfloat16_t simd_max(bfloat16_t data)</div><div class="ttdef"><b>Definition</b> bf16_math.h:392</div></div>
|
||||
<div class="ttc" id="anamespacemetal_html_a85181e37a00cb4a4217f1bb25389bce5"><div class="ttname"><a href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">metal::simd_sum</a></div><div class="ttdeci">METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)</div><div class="ttdef"><b>Definition</b> bf16_math.h:392</div></div>
|
||||
<div class="ttc" id="anamespacemetal_html_a853c80479ab2264d9c4587c7bcac767b"><div class="ttname"><a href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">metal::max</a></div><div class="ttdeci">METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)</div><div class="ttdef"><b>Definition</b> bf16_math.h:234</div></div>
|
||||
<div class="ttc" id="asdpa__vector_8h_html_a6f0d7918430064bab910bdaa6c64e927"><div class="ttname"><a href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a></div><div class="ttdeci">void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)</div><div class="ttdef"><b>Definition</b> sdpa_vector.h:8</div></div>
|
||||
<div class="ttc" id="asdpa__vector_8h_html_a4bf36f16e16c1c62d9b243573568e5ae"><div class="ttname"><a href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a></div><div class="ttdeci">void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)</div><div class="ttdef"><b>Definition</b> sdpa_vector.h:8</div></div>
|
||||
</div><!-- fragment --></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
|
2
docs/build/html/search/all_1.js
vendored
2
docs/build/html/search/all_1.js
vendored
@ -38,7 +38,7 @@ var searchData=
|
||||
['all_35',['all',['../group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68',1,'mlx::core::all(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga3689e12e8f42dadb4cbe2b07dc4099f4',1,'mlx::core::all(const array &a, StreamOrDevice s={})'],['../group__ops.html#gac0919c6ba53aea35a7683dea7e9a9a59',1,'mlx::core::all(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#gae2d5fcc5b62d673cca76c08b7b4afbbc',1,'mlx::core::all(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['all_5fgather_36',['all_gather',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04',1,'mlx::core::distributed::detail::all_gather()'],['../namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7',1,'mlx::core::distributed::all_gather()']]],
|
||||
['all_5freduce_37',['all_reduce',['../reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d',1,'reduce_all.h']]],
|
||||
['all_5freduce_5fdispatch_38',['all_reduce_dispatch',['../namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8',1,'mlx::core']]],
|
||||
['all_5freduce_5fdispatch_38',['all_reduce_dispatch',['../namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098',1,'mlx::core']]],
|
||||
['all_5fsum_39',['all_sum',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960',1,'mlx::core::distributed::detail::all_sum()'],['../namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980',1,'mlx::core::distributed::all_sum()']]],
|
||||
['allclose_40',['allclose',['../group__ops.html#gaf0cd4257de7542daf9faf5e605e31020',1,'mlx::core']]],
|
||||
['allgather_41',['AllGather',['../classmlx_1_1core_1_1distributed_1_1_all_gather.html',1,'mlx::core::distributed::AllGather'],['../classmlx_1_1core_1_1distributed_1_1_all_gather.html#af4b10a5b61f160fb64353057c185b661',1,'mlx::core::distributed::AllGather::AllGather()']]],
|
||||
|
3
docs/build/html/search/all_11.js
vendored
3
docs/build/html/search/all_11.js
vendored
@ -27,5 +27,6 @@ var searchData=
|
||||
['queue_24',['queue',['../structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d',1,'mlx::core::metal::DeviceStream']]],
|
||||
['quiet_5fnan_25',['quiet_NaN',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aebeb07c01984be246bc2d1b8f8e4ac7b',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['qvm_26',['qvm',['../quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5',1,'quantized.h']]],
|
||||
['qvm_5fimpl_27',['qvm_impl',['../quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12',1,'quantized.h']]]
|
||||
['qvm_5fimpl_27',['qvm_impl',['../quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a',1,'quantized.h']]],
|
||||
['qvm_5fsplit_5fk_28',['qvm_split_k',['../quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8',1,'quantized.h']]]
|
||||
];
|
||||
|
2
docs/build/html/search/all_13.js
vendored
2
docs/build/html/search/all_13.js
vendored
@ -28,7 +28,7 @@ var searchData=
|
||||
['scheduler_25',['Scheduler',['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html',1,'mlx::core::scheduler::Scheduler'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a3ae42aed78a2200e9d02776fcd2316ba',1,'mlx::core::scheduler::Scheduler::Scheduler()'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a61a74e3628899e66dde600e24a750648',1,'mlx::core::scheduler::Scheduler::Scheduler(const Scheduler &)=delete'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#ac3f77b7c93220dadd0b3bb2e903b7059',1,'mlx::core::scheduler::Scheduler::Scheduler(Scheduler &&)=delete']]],
|
||||
['scheduler_26',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html#ae856e468c2f7c8f8ec672522cc13730b',1,'mlx::core::scheduler']]],
|
||||
['scheduler_2eh_27',['scheduler.h',['../scheduler_8h.html',1,'']]],
|
||||
['sdpa_5fvector_28',['sdpa_vector',['../sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927',1,'sdpa_vector.h']]],
|
||||
['sdpa_5fvector_28',['sdpa_vector',['../sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae',1,'sdpa_vector.h']]],
|
||||
['sdpa_5fvector_2eh_29',['sdpa_vector.h',['../sdpa__vector_8h.html',1,'']]],
|
||||
['seed_30',['seed',['../classmlx_1_1core_1_1random_1_1_key_sequence.html#a9f19c5da2031cba50d0ff996924347d8',1,'mlx::core::random::KeySequence::seed()'],['../namespacemlx_1_1core_1_1random.html#ac4ad325b613257306df74595d3d0e23b',1,'mlx::core::random::seed()']]],
|
||||
['seek_31',['seek',['../structmlx_1_1core_1_1_contiguous_iterator.html#a24719ee9e8667885d29c2ad74445520c',1,'mlx::core::ContiguousIterator::seek()'],['../classmlx_1_1core_1_1io_1_1_reader.html#acea55078bd39ccaa27a9a36f17a39cd1',1,'mlx::core::io::Reader::seek()'],['../classmlx_1_1core_1_1io_1_1_writer.html#a9c1716dda53aa36faea9c8fb1a3e34d4',1,'mlx::core::io::Writer::seek()'],['../classmlx_1_1core_1_1io_1_1_parallel_file_reader.html#a673c16b669f3cee13f387b7b0a1f39f7',1,'mlx::core::io::ParallelFileReader::seek()'],['../classmlx_1_1core_1_1io_1_1_file_writer.html#a9646f4ea048ae58719daeb588e2de433',1,'mlx::core::io::FileWriter::seek()']]],
|
||||
|
230
docs/build/html/search/all_3.js
vendored
230
docs/build/html/search/all_3.js
vendored
@ -35,118 +35,120 @@ var searchData=
|
||||
['cmplx_3c_20thigh_20_3e_32',['cmplx< Thigh >',['../structpocketfft_1_1detail_1_1cmplx.html',1,'pocketfft::detail']]],
|
||||
['cndarr_33',['cndarr',['../classpocketfft_1_1detail_1_1cndarr.html',1,'pocketfft::detail::cndarr< T >'],['../classpocketfft_1_1detail_1_1cndarr.html#abf73f1b4ddcfb27d7f85cfa441607129',1,'pocketfft::detail::cndarr::cndarr()']]],
|
||||
['col_5fcontiguous_34',['col_contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#ae24709026598d635e6b5c24a15f8a802',1,'mlx::core::array::Flags']]],
|
||||
['col_5freduce_5flooped_35',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
|
||||
['col_5freduce_5fsmall_36',['col_reduce_small',['../reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce',1,'reduce_col.h']]],
|
||||
['collapse_5fcontiguous_5fdims_37',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< size_t > > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector< array > &xs, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &&... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< int64_t > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< size_t > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &a, size_t size_cap=std::numeric_limits< int32_t >::max())']]],
|
||||
['commandencoder_38',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html',1,'mlx::core::metal::CommandEncoder'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &)=delete']]],
|
||||
['commit_5fcommand_5fbuffer_39',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
|
||||
['commonallocator_40',['CommonAllocator',['../classmlx_1_1core_1_1allocator_1_1_common_allocator.html',1,'mlx::core::allocator']]],
|
||||
['communication_5fstream_41',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
|
||||
['compile_42',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile()']]],
|
||||
['compile_2eh_43',['compile.h',['../compile_8h.html',1,'']]],
|
||||
['compile_5favailable_5ffor_5fdevice_44',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
|
||||
['compile_5fclear_5fcache_45',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
|
||||
['compile_5ferase_46',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
|
||||
['compile_5fimpl_2eh_47',['compile_impl.h',['../compile__impl_8h.html',1,'']]],
|
||||
['compiled_48',['Compiled',['../classmlx_1_1core_1_1_compiled.html',1,'mlx::core::Compiled'],['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled::Compiled()']]],
|
||||
['compiled_2eh_49',['compiled.h',['../compiled_8h.html',1,'']]],
|
||||
['compiled_5fallocate_5foutputs_50',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
|
||||
['compiled_5fcheck_5fcontiguity_51',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
|
||||
['compiled_5fpreamble_2eh_52',['compiled_preamble.h',['../compiled__preamble_8h.html',1,'']]],
|
||||
['compilemode_53',['CompileMode',['../namespacemlx_1_1core.html#adb15ff2b1ca5207fd4f6e631e2c3bcb4',1,'mlx::core']]],
|
||||
['complex_2eh_54',['complex.h',['../backend_2metal_2kernels_2complex_8h.html',1,'(Global Namespace)'],['../types_2complex_8h.html',1,'(Global Namespace)']]],
|
||||
['complex128_5ft_55',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html',1,'mlx::core::complex128_t'],['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex< double > v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
|
||||
['complex64_56',['complex64',['../structmlx_1_1core_1_1_dtype.html#ade845ef5dcebead13a37fe696436e1daa8c022579455bcd2c681f007e84f4e2cf',1,'mlx::core::Dtype::complex64'],['../namespacemlx_1_1core.html#af99db87e0078bfcdb383f5689bc874d4',1,'mlx::core::complex64']]],
|
||||
['complex64_5ft_57',['complex64_t',['../structcomplex64__t.html',1,'complex64_t'],['../structmlx_1_1core_1_1complex64__t.html',1,'mlx::core::complex64_t'],['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex< float > v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
|
||||
['complex_5fbinop_58',['complex_binop',['../types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd',1,'complex.h']]],
|
||||
['complex_5fbinop_5fhelper_59',['complex_binop_helper',['../types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4',1,'complex.h']]],
|
||||
['complex_5fmul_60',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
|
||||
['complex_5fmul_5fconj_61',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
|
||||
['complexfloating_62',['complexfloating',['../structmlx_1_1core_1_1_dtype.html#ac091c39cbd6686ef69aa1e5a2425aa2dafb203630099d501ff7c255a574bc4812',1,'mlx::core::Dtype::complexfloating'],['../namespacemlx_1_1core.html#a70b8e88c9df750af984757105af33423',1,'mlx::core::complexfloating']]],
|
||||
['compute_5fstrided_5findices_63',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
|
||||
['concatenate_64',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html',1,'mlx::core::Concatenate'],['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate::Concatenate()']]],
|
||||
['concatenate_65',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector< array > &arrays, StreamOrDevice s={})']]],
|
||||
['concatenate_5fgpu_66',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
|
||||
['concurrent_5fqueue_67',['concurrent_queue',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
|
||||
['concurrent_5fqueue_3c_20std_3a_3afunction_3c_20void_28_29_3e_20_3e_68',['concurrent_queue< std::function< void()> >',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
|
||||
['concurrentcontext_69',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html',1,'mlx::core::metal::CommandEncoder::ConcurrentContext'],['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext::ConcurrentContext()']]],
|
||||
['cond_70',['cond',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a4ffd524d6a5bedd1a303b63bdde6701c',1,'mlx::core::scheduler::StreamThread']]],
|
||||
['conj_71',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
|
||||
['conjugate_72',['Conjugate',['../struct_conjugate.html',1,'Conjugate'],['../classmlx_1_1core_1_1_conjugate.html',1,'mlx::core::Conjugate'],['../structmlx_1_1core_1_1detail_1_1_conjugate.html',1,'mlx::core::detail::Conjugate'],['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate::Conjugate()']]],
|
||||
['conjugate_73',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
|
||||
['contiguous_74',['contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#afd0ab11e7a486a2a8e50ee84b971ac8a',1,'mlx::core::array::Flags']]],
|
||||
['contiguous_5fscan_75',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
|
||||
['contiguousallreduce_76',['ContiguousAllReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ae4e34c7154eb8dc47aa8503209730424',1,'mlx::core']]],
|
||||
['contiguousiterator_77',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html',1,'mlx::core::ContiguousIterator< StrideT >'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector< int > &shape, const std::vector< StrideT > &strides, int dims)']]],
|
||||
['contiguousreduce_78',['ContiguousReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ad2547f25dffe8d8936dbec25601cfc84',1,'mlx::core']]],
|
||||
['contiguousstridedreduce_79',['ContiguousStridedReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ab48dac7508a2c790de1bdc33f29177ed',1,'mlx::core']]],
|
||||
['conv_80',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
|
||||
['conv_2eh_81',['conv.h',['../conv_8h.html',1,'']]],
|
||||
['conv1d_82',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
|
||||
['conv2d_83',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
|
||||
['conv2dgeneralbaseinfo_84',['Conv2DGeneralBaseInfo',['../structmlx_1_1steel_1_1_conv2_d_general_base_info.html',1,'mlx::steel']]],
|
||||
['conv2dgeneraljumpparams_85',['Conv2DGeneralJumpParams',['../structmlx_1_1steel_1_1_conv2_d_general_jump_params.html',1,'mlx::steel']]],
|
||||
['conv2dinputblockloadergeneral_86',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html',1,'mlx::steel::Conv2DInputBlockLoaderGeneral< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral::Conv2DInputBlockLoaderGeneral()']]],
|
||||
['conv2dinputblockloaderlargefilter_87',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter::Conv2DInputBlockLoaderLargeFilter()']]],
|
||||
['conv2dinputblockloadersmallchannels_88',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, n_channels, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels::Conv2DInputBlockLoaderSmallChannels()']]],
|
||||
['conv2dinputblockloadersmallfilter_89',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter::Conv2DInputBlockLoaderSmallFilter()']]],
|
||||
['conv2dweightblockloader_90',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html',1,'mlx::steel::Conv2DWeightBlockLoader< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader::Conv2DWeightBlockLoader()']]],
|
||||
['conv2dweightblockloadergeneral_91',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral::Conv2DWeightBlockLoaderGeneral()']]],
|
||||
['conv2dweightblockloadersmallchannels_92',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, n_channels, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels::Conv2DWeightBlockLoaderSmallChannels()']]],
|
||||
['conv3d_93',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
|
||||
['conv_5fgeneral_94',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
|
||||
['conv_5ftranspose1d_95',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
|
||||
['conv_5ftranspose2d_96',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
|
||||
['conv_5ftranspose3d_97',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
|
||||
['convolution_98',['Convolution',['../classmlx_1_1core_1_1_convolution.html',1,'mlx::core::Convolution'],['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution::Convolution()']]],
|
||||
['copy_99',['Copy',['../classmlx_1_1core_1_1_copy.html',1,'mlx::core::Copy'],['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy::Copy()']]],
|
||||
['copy_100',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
|
||||
['copy_2eh_101',['copy.h',['../common_2copy_8h.html',1,'(Global Namespace)'],['../metal_2copy_8h.html',1,'(Global Namespace)'],['../metal_2kernels_2copy_8h.html',1,'(Global Namespace)']]],
|
||||
['copy_5fg_102',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
|
||||
['copy_5fg_5fnd1_103',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
|
||||
['copy_5fg_5fnd2_104',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
|
||||
['copy_5fg_5fnd3_105',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
|
||||
['copy_5fgg_106',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd1_107',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd2_108',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd3_109',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
|
||||
['copy_5fgpu_110',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype)']]],
|
||||
['copy_5fgpu_5finplace_111',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int64_t > &istride, int64_t ioffset, CopyType ctype, const Stream &s)']]],
|
||||
['copy_5fhartley_112',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5finplace_113',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &src, array &dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
|
||||
['copy_5finput_114',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< cmplx< T > > &src, cmplx< vtype_t< T > > *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, vtype_t< T > *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, T *dst)']]],
|
||||
['copy_5foutput_115',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const cmplx< vtype_t< T > > *src, ndarr< cmplx< T > > &dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5fs_116',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
|
||||
['copy_5fs2_117',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
|
||||
['copy_5fshared_5fbuffer_118',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &other)']]],
|
||||
['copy_5fv_119',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
|
||||
['copy_5fv2_120',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
|
||||
['copytype_121',['CopyType',['../namespacemlx_1_1core.html#abd84ff6c5245e4e170b2ef5247594337',1,'mlx::core']]],
|
||||
['core_20array_20operations_122',['Core array operations',['../group__ops.html',1,'']]],
|
||||
['cos_123',['Cos',['../struct_cos.html',1,'Cos'],['../classmlx_1_1core_1_1_cos.html',1,'mlx::core::Cos'],['../structmlx_1_1core_1_1detail_1_1_cos.html',1,'mlx::core::detail::Cos'],['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos::Cos()']]],
|
||||
['cos_124',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
|
||||
['cosh_125',['Cosh',['../struct_cosh.html',1,'Cosh'],['../classmlx_1_1core_1_1_cosh.html',1,'mlx::core::Cosh'],['../structmlx_1_1core_1_1detail_1_1_cosh.html',1,'mlx::core::detail::Cosh'],['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh::Cosh()']]],
|
||||
['cosh_126',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
|
||||
['cosine_127',['cosine',['../structpocketfft_1_1detail_1_1_exec_dcst.html#a185023fc1e386cc8f233b79c49c1fd8a',1,'pocketfft::detail::ExecDcst']]],
|
||||
['cospi_128',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
|
||||
['cost_5fguess_129',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
|
||||
['count_5fdown_130',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
|
||||
['cpu_131',['cpu',['../structmlx_1_1core_1_1_device.html#a69ee81924251dec96f1945c9d91506fd',1,'mlx::core::Device::cpu'],['../structmlx_1_1core_1_1_device.html#ac45b3de9b3458d8f31005136cde20fdbad9747e2da342bdb995f6389533ad1a3d',1,'mlx::core::Device::cpu']]],
|
||||
['cross_132',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
|
||||
['ctile_133',['Ctile',['../structmlx_1_1steel_1_1_block_m_m_a.html#a81838da5d81e62d372d581be599c5a88',1,'mlx::steel::BlockMMA']]],
|
||||
['cummax_134',['CumMax',['../struct_cum_max.html',1,'']]],
|
||||
['cummax_135',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
|
||||
['cummin_136',['CumMin',['../struct_cum_min.html',1,'']]],
|
||||
['cummin_137',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
|
||||
['cumprod_138',['CumProd',['../struct_cum_prod.html',1,'']]],
|
||||
['cumprod_139',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
|
||||
['cumprod_3c_20bool_20_3e_140',['CumProd< bool >',['../struct_cum_prod_3_01bool_01_4.html',1,'']]],
|
||||
['cumsum_141',['CumSum',['../struct_cum_sum.html',1,'']]],
|
||||
['cumsum_142',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
|
||||
['custom_143',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html',1,'mlx::core::fast::Custom'],['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom::Custom()']]],
|
||||
['custom_5ffunction_144',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
|
||||
['custom_5fvjp_145',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
|
||||
['customkernel_146',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html',1,'mlx::core::fast::CustomKernel'],['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel::CustomKernel()']]],
|
||||
['customkernelshapeinfo_147',['CustomKernelShapeInfo',['../structmlx_1_1core_1_1fast_1_1_custom_kernel_shape_info.html',1,'mlx::core::fast']]],
|
||||
['customtransforms_148',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html',1,'mlx::core::CustomTransforms'],['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms::CustomTransforms()']]]
|
||||
['col_5freduce_5f2pass_35',['col_reduce_2pass',['../reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d',1,'reduce_col.h']]],
|
||||
['col_5freduce_5flongcolumn_36',['col_reduce_longcolumn',['../reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb',1,'reduce_col.h']]],
|
||||
['col_5freduce_5flooped_37',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
|
||||
['col_5freduce_5fsmall_38',['col_reduce_small',['../reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5',1,'reduce_col.h']]],
|
||||
['collapse_5fcontiguous_5fdims_39',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< size_t > > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector< array > &xs, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &&... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< int64_t > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< size_t > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &a, size_t size_cap=std::numeric_limits< int32_t >::max())']]],
|
||||
['commandencoder_40',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html',1,'mlx::core::metal::CommandEncoder'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &)=delete']]],
|
||||
['commit_5fcommand_5fbuffer_41',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
|
||||
['commonallocator_42',['CommonAllocator',['../classmlx_1_1core_1_1allocator_1_1_common_allocator.html',1,'mlx::core::allocator']]],
|
||||
['communication_5fstream_43',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
|
||||
['compile_44',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile()']]],
|
||||
['compile_2eh_45',['compile.h',['../compile_8h.html',1,'']]],
|
||||
['compile_5favailable_5ffor_5fdevice_46',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
|
||||
['compile_5fclear_5fcache_47',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
|
||||
['compile_5ferase_48',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
|
||||
['compile_5fimpl_2eh_49',['compile_impl.h',['../compile__impl_8h.html',1,'']]],
|
||||
['compiled_50',['Compiled',['../classmlx_1_1core_1_1_compiled.html',1,'mlx::core::Compiled'],['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled::Compiled()']]],
|
||||
['compiled_2eh_51',['compiled.h',['../compiled_8h.html',1,'']]],
|
||||
['compiled_5fallocate_5foutputs_52',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
|
||||
['compiled_5fcheck_5fcontiguity_53',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
|
||||
['compiled_5fpreamble_2eh_54',['compiled_preamble.h',['../compiled__preamble_8h.html',1,'']]],
|
||||
['compilemode_55',['CompileMode',['../namespacemlx_1_1core.html#adb15ff2b1ca5207fd4f6e631e2c3bcb4',1,'mlx::core']]],
|
||||
['complex_2eh_56',['complex.h',['../backend_2metal_2kernels_2complex_8h.html',1,'(Global Namespace)'],['../types_2complex_8h.html',1,'(Global Namespace)']]],
|
||||
['complex128_5ft_57',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html',1,'mlx::core::complex128_t'],['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex< double > v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
|
||||
['complex64_58',['complex64',['../structmlx_1_1core_1_1_dtype.html#ade845ef5dcebead13a37fe696436e1daa8c022579455bcd2c681f007e84f4e2cf',1,'mlx::core::Dtype::complex64'],['../namespacemlx_1_1core.html#af99db87e0078bfcdb383f5689bc874d4',1,'mlx::core::complex64']]],
|
||||
['complex64_5ft_59',['complex64_t',['../structcomplex64__t.html',1,'complex64_t'],['../structmlx_1_1core_1_1complex64__t.html',1,'mlx::core::complex64_t'],['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex< float > v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
|
||||
['complex_5fbinop_60',['complex_binop',['../types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd',1,'complex.h']]],
|
||||
['complex_5fbinop_5fhelper_61',['complex_binop_helper',['../types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4',1,'complex.h']]],
|
||||
['complex_5fmul_62',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
|
||||
['complex_5fmul_5fconj_63',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
|
||||
['complexfloating_64',['complexfloating',['../structmlx_1_1core_1_1_dtype.html#ac091c39cbd6686ef69aa1e5a2425aa2dafb203630099d501ff7c255a574bc4812',1,'mlx::core::Dtype::complexfloating'],['../namespacemlx_1_1core.html#a70b8e88c9df750af984757105af33423',1,'mlx::core::complexfloating']]],
|
||||
['compute_5fstrided_5findices_65',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
|
||||
['concatenate_66',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html',1,'mlx::core::Concatenate'],['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate::Concatenate()']]],
|
||||
['concatenate_67',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector< array > &arrays, StreamOrDevice s={})']]],
|
||||
['concatenate_5fgpu_68',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
|
||||
['concurrent_5fqueue_69',['concurrent_queue',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
|
||||
['concurrent_5fqueue_3c_20std_3a_3afunction_3c_20void_28_29_3e_20_3e_70',['concurrent_queue< std::function< void()> >',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
|
||||
['concurrentcontext_71',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html',1,'mlx::core::metal::CommandEncoder::ConcurrentContext'],['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext::ConcurrentContext()']]],
|
||||
['cond_72',['cond',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a4ffd524d6a5bedd1a303b63bdde6701c',1,'mlx::core::scheduler::StreamThread']]],
|
||||
['conj_73',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
|
||||
['conjugate_74',['Conjugate',['../struct_conjugate.html',1,'Conjugate'],['../classmlx_1_1core_1_1_conjugate.html',1,'mlx::core::Conjugate'],['../structmlx_1_1core_1_1detail_1_1_conjugate.html',1,'mlx::core::detail::Conjugate'],['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate::Conjugate()']]],
|
||||
['conjugate_75',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
|
||||
['contiguous_76',['contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#afd0ab11e7a486a2a8e50ee84b971ac8a',1,'mlx::core::array::Flags']]],
|
||||
['contiguous_5fscan_77',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
|
||||
['contiguousallreduce_78',['ContiguousAllReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ae4e34c7154eb8dc47aa8503209730424',1,'mlx::core']]],
|
||||
['contiguousiterator_79',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html',1,'mlx::core::ContiguousIterator< StrideT >'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector< int > &shape, const std::vector< StrideT > &strides, int dims)']]],
|
||||
['contiguousreduce_80',['ContiguousReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ad2547f25dffe8d8936dbec25601cfc84',1,'mlx::core']]],
|
||||
['contiguousstridedreduce_81',['ContiguousStridedReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ab48dac7508a2c790de1bdc33f29177ed',1,'mlx::core']]],
|
||||
['conv_82',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
|
||||
['conv_2eh_83',['conv.h',['../conv_8h.html',1,'']]],
|
||||
['conv1d_84',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
|
||||
['conv2d_85',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
|
||||
['conv2dgeneralbaseinfo_86',['Conv2DGeneralBaseInfo',['../structmlx_1_1steel_1_1_conv2_d_general_base_info.html',1,'mlx::steel']]],
|
||||
['conv2dgeneraljumpparams_87',['Conv2DGeneralJumpParams',['../structmlx_1_1steel_1_1_conv2_d_general_jump_params.html',1,'mlx::steel']]],
|
||||
['conv2dinputblockloadergeneral_88',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html',1,'mlx::steel::Conv2DInputBlockLoaderGeneral< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral::Conv2DInputBlockLoaderGeneral()']]],
|
||||
['conv2dinputblockloaderlargefilter_89',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter::Conv2DInputBlockLoaderLargeFilter()']]],
|
||||
['conv2dinputblockloadersmallchannels_90',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, n_channels, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels::Conv2DInputBlockLoaderSmallChannels()']]],
|
||||
['conv2dinputblockloadersmallfilter_91',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter::Conv2DInputBlockLoaderSmallFilter()']]],
|
||||
['conv2dweightblockloader_92',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html',1,'mlx::steel::Conv2DWeightBlockLoader< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader::Conv2DWeightBlockLoader()']]],
|
||||
['conv2dweightblockloadergeneral_93',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral< T, BM, BN, BK, tgp_size, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral::Conv2DWeightBlockLoaderGeneral()']]],
|
||||
['conv2dweightblockloadersmallchannels_94',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels< T, BM, BN, BK, tgp_size, n_channels, tgp_padding >'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels::Conv2DWeightBlockLoaderSmallChannels()']]],
|
||||
['conv3d_95',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
|
||||
['conv_5fgeneral_96',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
|
||||
['conv_5ftranspose1d_97',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
|
||||
['conv_5ftranspose2d_98',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
|
||||
['conv_5ftranspose3d_99',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
|
||||
['convolution_100',['Convolution',['../classmlx_1_1core_1_1_convolution.html',1,'mlx::core::Convolution'],['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution::Convolution()']]],
|
||||
['copy_101',['Copy',['../classmlx_1_1core_1_1_copy.html',1,'mlx::core::Copy'],['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy::Copy()']]],
|
||||
['copy_102',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
|
||||
['copy_2eh_103',['copy.h',['../common_2copy_8h.html',1,'(Global Namespace)'],['../metal_2copy_8h.html',1,'(Global Namespace)'],['../metal_2kernels_2copy_8h.html',1,'(Global Namespace)']]],
|
||||
['copy_5fg_104',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
|
||||
['copy_5fg_5fnd1_105',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
|
||||
['copy_5fg_5fnd2_106',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
|
||||
['copy_5fg_5fnd3_107',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
|
||||
['copy_5fgg_108',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd1_109',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd2_110',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd3_111',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
|
||||
['copy_5fgpu_112',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype)']]],
|
||||
['copy_5fgpu_5finplace_113',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int64_t > &istride, int64_t ioffset, CopyType ctype, const Stream &s)']]],
|
||||
['copy_5fhartley_114',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5finplace_115',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &src, array &dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
|
||||
['copy_5finput_116',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< cmplx< T > > &src, cmplx< vtype_t< T > > *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, vtype_t< T > *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, T *dst)']]],
|
||||
['copy_5foutput_117',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const cmplx< vtype_t< T > > *src, ndarr< cmplx< T > > &dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5fs_118',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
|
||||
['copy_5fs2_119',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
|
||||
['copy_5fshared_5fbuffer_120',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &other)']]],
|
||||
['copy_5fv_121',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
|
||||
['copy_5fv2_122',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
|
||||
['copytype_123',['CopyType',['../namespacemlx_1_1core.html#abd84ff6c5245e4e170b2ef5247594337',1,'mlx::core']]],
|
||||
['core_20array_20operations_124',['Core array operations',['../group__ops.html',1,'']]],
|
||||
['cos_125',['Cos',['../struct_cos.html',1,'Cos'],['../classmlx_1_1core_1_1_cos.html',1,'mlx::core::Cos'],['../structmlx_1_1core_1_1detail_1_1_cos.html',1,'mlx::core::detail::Cos'],['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos::Cos()']]],
|
||||
['cos_126',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
|
||||
['cosh_127',['Cosh',['../struct_cosh.html',1,'Cosh'],['../classmlx_1_1core_1_1_cosh.html',1,'mlx::core::Cosh'],['../structmlx_1_1core_1_1detail_1_1_cosh.html',1,'mlx::core::detail::Cosh'],['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh::Cosh()']]],
|
||||
['cosh_128',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
|
||||
['cosine_129',['cosine',['../structpocketfft_1_1detail_1_1_exec_dcst.html#a185023fc1e386cc8f233b79c49c1fd8a',1,'pocketfft::detail::ExecDcst']]],
|
||||
['cospi_130',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
|
||||
['cost_5fguess_131',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
|
||||
['count_5fdown_132',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
|
||||
['cpu_133',['cpu',['../structmlx_1_1core_1_1_device.html#a69ee81924251dec96f1945c9d91506fd',1,'mlx::core::Device::cpu'],['../structmlx_1_1core_1_1_device.html#ac45b3de9b3458d8f31005136cde20fdbad9747e2da342bdb995f6389533ad1a3d',1,'mlx::core::Device::cpu']]],
|
||||
['cross_134',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
|
||||
['ctile_135',['Ctile',['../structmlx_1_1steel_1_1_block_m_m_a.html#a81838da5d81e62d372d581be599c5a88',1,'mlx::steel::BlockMMA']]],
|
||||
['cummax_136',['CumMax',['../struct_cum_max.html',1,'']]],
|
||||
['cummax_137',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
|
||||
['cummin_138',['CumMin',['../struct_cum_min.html',1,'']]],
|
||||
['cummin_139',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
|
||||
['cumprod_140',['CumProd',['../struct_cum_prod.html',1,'']]],
|
||||
['cumprod_141',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
|
||||
['cumprod_3c_20bool_20_3e_142',['CumProd< bool >',['../struct_cum_prod_3_01bool_01_4.html',1,'']]],
|
||||
['cumsum_143',['CumSum',['../struct_cum_sum.html',1,'']]],
|
||||
['cumsum_144',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
|
||||
['custom_145',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html',1,'mlx::core::fast::Custom'],['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom::Custom()']]],
|
||||
['custom_5ffunction_146',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
|
||||
['custom_5fvjp_147',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
|
||||
['customkernel_148',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html',1,'mlx::core::fast::CustomKernel'],['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel::CustomKernel()']]],
|
||||
['customkernelshapeinfo_149',['CustomKernelShapeInfo',['../structmlx_1_1core_1_1fast_1_1_custom_kernel_shape_info.html',1,'mlx::core::fast']]],
|
||||
['customtransforms_150',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html',1,'mlx::core::CustomTransforms'],['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms::CustomTransforms()']]]
|
||||
];
|
||||
|
2
docs/build/html/search/all_7.js
vendored
2
docs/build/html/search/all_7.js
vendored
@ -67,7 +67,7 @@ var searchData=
|
||||
['get_5fpool_64',['get_pool',['../namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a',1,'pocketfft::detail::threading']]],
|
||||
['get_5fprimitive_5fstring_65',['get_primitive_string',['../namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60',1,'mlx::core']]],
|
||||
['get_5fquantized_5fkernel_66',['get_quantized_kernel',['../namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e',1,'mlx::core']]],
|
||||
['get_5freduce_5finit_5fkernel_67',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89',1,'mlx::core']]],
|
||||
['get_5freduce_5finit_5fkernel_67',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647',1,'mlx::core']]],
|
||||
['get_5freduce_5fkernel_68',['get_reduce_kernel',['../namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b',1,'mlx::core']]],
|
||||
['get_5freduction_5fplan_69',['get_reduction_plan',['../namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba',1,'mlx::core']]],
|
||||
['get_5fscan_5fkernel_70',['get_scan_kernel',['../namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f',1,'mlx::core']]],
|
||||
|
161
docs/build/html/search/all_d.js
vendored
161
docs/build/html/search/all_d.js
vendored
@ -29,84 +29,85 @@ var searchData=
|
||||
['max_5fthreads_26',['max_threads',['../namespacepocketfft_1_1detail_1_1threading.html#a2d5c0729f0b66cf061918baea4337d70',1,'pocketfft::detail::threading']]],
|
||||
['maximum_27',['Maximum',['../struct_maximum.html',1,'Maximum'],['../structmlx_1_1core_1_1detail_1_1_maximum.html',1,'mlx::core::detail::Maximum'],['../classmlx_1_1core_1_1_maximum.html',1,'mlx::core::Maximum'],['../classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816',1,'mlx::core::Maximum::Maximum()']]],
|
||||
['maximum_28',['maximum',['../group__ops.html#ga7ade2ea305e2e4219c3609443fb5db8d',1,'mlx::core']]],
|
||||
['mb_5fblock_5fmerge_29',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
|
||||
['mb_5fblock_5fpartition_30',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
|
||||
['mb_5fblock_5fsort_31',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
|
||||
['mean_32',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['median3_33',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
|
||||
['merge_5fpartition_34',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
|
||||
['merge_5fstep_35',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
|
||||
['meshgrid_36',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
|
||||
['metal_37',['metal',['../namespacemetal.html',1,'']]],
|
||||
['metal_2eh_38',['metal.h',['../metal_8h.html',1,'']]],
|
||||
['metal_3a_3afast_39',['fast',['../namespacemetal_1_1fast.html',1,'metal']]],
|
||||
['metal_3a_3aprecise_40',['precise',['../namespacemetal_1_1precise.html',1,'metal']]],
|
||||
['metal_5fimpl_2eh_41',['metal_impl.h',['../metal__impl_8h.html',1,'']]],
|
||||
['metal_5fkernel_42',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
|
||||
['metalallocator_43',['MetalAllocator',['../classmlx_1_1core_1_1metal_1_1_metal_allocator.html',1,'mlx::core::metal']]],
|
||||
['metalkernelfunction_44',['MetalKernelFunction',['../namespacemlx_1_1core_1_1fast.html#a0e8c2c4ea7a946568c8fe5b4810417e0',1,'mlx::core::fast']]],
|
||||
['min_45',['Min',['../struct_min.html',1,'Min< U >'],['../classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef',1,'mlx::core::distributed::AllReduce::Min'],['../classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418',1,'mlx::core::Reduce::Min'],['../classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898',1,'mlx::core::Scan::Min'],['../classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf',1,'mlx::core::Scatter::Min']]],
|
||||
['min_46',['min',['../struct_limits.html#a6e81584ba65a4dc6ff9366b458e3a20e',1,'Limits::min'],['../struct_limits_3_01uint8__t_01_4.html#a408bd5a337e7292f06e63da81193629a',1,'Limits< uint8_t >::min'],['../struct_limits_3_01uint16__t_01_4.html#ae173984c3be8b6750f27daed581805fe',1,'Limits< uint16_t >::min'],['../struct_limits_3_01uint32__t_01_4.html#ab0c3975e02053b234c7b606ababa66e1',1,'Limits< uint32_t >::min'],['../struct_limits_3_01uint64__t_01_4.html#a80627f39e951398283942cefa48f4dd0',1,'Limits< uint64_t >::min'],['../struct_limits_3_01int8__t_01_4.html#a7a809307d2bba80382f0645d277eaa4b',1,'Limits< int8_t >::min'],['../struct_limits_3_01int16__t_01_4.html#adca7139647801e223c35b0abc7da5240',1,'Limits< int16_t >::min'],['../struct_limits_3_01int32__t_01_4.html#af336a1b22a8ed6a83a4cfb5bf8869771',1,'Limits< int32_t >::min'],['../struct_limits_3_01int64__t_01_4.html#a1c90fb96af515badaccaa835b08f7428',1,'Limits< int64_t >::min'],['../struct_limits_3_01half_01_4.html#aca7b036c257878bf1b80912fb5d4516d',1,'Limits< half >::min'],['../struct_limits_3_01float_01_4.html#a3225e334d372ee86128c89a440d8648f',1,'Limits< float >::min'],['../struct_limits_3_01bfloat16__t_01_4.html#a2fd1811b9f615b2b897904bc27d1cb49',1,'Limits< bfloat16_t >::min'],['../struct_limits_3_01bool_01_4.html#a139f787b57536d455490b8ef801d37cc',1,'Limits< bool >::min'],['../struct_limits_3_01complex64__t_01_4.html#aa67b04aa7abcd67f7af0808737ab8e14',1,'Limits< complex64_t >::min'],['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl< bfloat16_t >::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['min3_47',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
|
||||
['min_5fexponent_48',['min_exponent',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a13829f8c7a7c0efdc8946eff5d3c9470',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['min_5fexponent10_49',['min_exponent10',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aeaed172780720e06b8731cef3177e277',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['minimum_50',['Minimum',['../struct_minimum.html',1,'Minimum'],['../structmlx_1_1core_1_1detail_1_1_minimum.html',1,'mlx::core::detail::Minimum'],['../classmlx_1_1core_1_1_minimum.html',1,'mlx::core::Minimum'],['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum::Minimum()']]],
|
||||
['minimum_51',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
|
||||
['mlx_52',['mlx',['../namespacemlx.html',1,'']]],
|
||||
['mlx_2eh_53',['mlx.h',['../mlx_8h.html',1,'']]],
|
||||
['mlx_3a_3acore_54',['core',['../namespacemlx_1_1core.html',1,'mlx']]],
|
||||
['mlx_3a_3acore_3a_3aallocator_55',['allocator',['../namespacemlx_1_1core_1_1allocator.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adetail_56',['detail',['../namespacemlx_1_1core_1_1detail.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adistributed_57',['distributed',['../namespacemlx_1_1core_1_1distributed.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adistributed_3a_3adetail_58',['detail',['../namespacemlx_1_1core_1_1distributed_1_1detail.html',1,'mlx::core::distributed']]],
|
||||
['mlx_3a_3acore_3a_3afast_59',['fast',['../namespacemlx_1_1core_1_1fast.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3afft_60',['fft',['../namespacemlx_1_1core_1_1fft.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3aio_61',['io',['../namespacemlx_1_1core_1_1io.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3alinalg_62',['linalg',['../namespacemlx_1_1core_1_1linalg.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3ametal_63',['metal',['../namespacemlx_1_1core_1_1metal.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3arandom_64',['random',['../namespacemlx_1_1core_1_1random.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3ascheduler_65',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html',1,'mlx::core']]],
|
||||
['mlx_3a_3asteel_66',['steel',['../namespacemlx_1_1steel.html',1,'mlx']]],
|
||||
['mlx_5fatomic_67',['mlx_atomic',['../structmlx__atomic.html',1,'']]],
|
||||
['mlx_5fatomic_3c_20t_2c_20enable_5fif_5ft_3c_20is_5fmetal_5fatomic_3c_20t_20_3e_20_3e_20_3e_68',['mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >',['../structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4.html',1,'']]],
|
||||
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_69',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset): atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread uint *expected, uint val, size_t offset): atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_70',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fand_5fexplicit_71',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_72',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_73',['mlx_atomic_fetch_max_explicit< float >',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_74',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_75',['mlx_atomic_fetch_min_explicit< float >',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_76',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5for_5fexplicit_77',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fload_5fexplicit_78',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fstore_5fexplicit_79',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
|
||||
['mlx_5flapack_5ffunc_80',['MLX_LAPACK_FUNC',['../lapack_8h.html#ae22db9704827bf013a0a61f21a47464b',1,'lapack.h']]],
|
||||
['mlx_5fmtl_5fconst_81',['MLX_MTL_CONST',['../kernels_2gemv__masked_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: gemv_masked.h'],['../quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: quantized.h'],['../sort_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: sort.h']]],
|
||||
['mlx_5fmtl_5floop_5funroll_82',['MLX_MTL_LOOP_UNROLL',['../sort_8h.html#ad34b622323cebef136669fedd7229515',1,'sort.h']]],
|
||||
['mlx_5fmtl_5fpragma_5funroll_83',['MLX_MTL_PRAGMA_UNROLL',['../kernels_2gemv__masked_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL: gemv_masked.h'],['../backend_2metal_2kernels_2utils_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL: utils.h']]],
|
||||
['mlxconvparams_84',['MLXConvParams',['../struct_m_l_x_conv_params.html',1,'']]],
|
||||
['mlxconvparams_3c_202_20_3e_85',['MLXConvParams< 2 >',['../struct_m_l_x_conv_params.html',1,'']]],
|
||||
['mlxfastattentionparams_86',['MLXFastAttentionParams',['../struct_m_l_x_fast_attention_params.html',1,'']]],
|
||||
['mlxscaleddotproductattentionparams_87',['MLXScaledDotProductAttentionParams',['../struct_m_l_x_scaled_dot_product_attention_params.html',1,'']]],
|
||||
['mma_88',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
|
||||
['mma_2eh_89',['mma.h',['../mma_8h.html',1,'']]],
|
||||
['mma_5ft_90',['mma_t',['../structmlx_1_1steel_1_1_g_e_m_m_kernel.html#add8c6a31011a4895667c2a94a5af3782',1,'mlx::steel::GEMMKernel']]],
|
||||
['mmafrag_5facc_5ft_91',['MMAFrag_acc_t',['../structmlx_1_1steel_1_1_block_m_m_a.html#ae2c42cb6d0dde785859164c195f4d13c',1,'mlx::steel::BlockMMA']]],
|
||||
['mmafrag_5ft_92',['MMAFrag_t',['../structmlx_1_1steel_1_1_m_m_a_tile.html#abe33de70e34300745bad9aa822fd0382',1,'mlx::steel::MMATile']]],
|
||||
['mmatile_93',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >'],['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile::MMATile()']]],
|
||||
['mmatile_3c_20float_2c_201_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_94',['MMATile< float, 1, TN, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['mmatile_3c_20float_2c_20tm_2c_201_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_95',['MMATile< float, TM, 1, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['mmatile_3c_20float_2c_20tm_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_96',['MMATile< float, TM, TN, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['move_5fshared_5fbuffer_97',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
|
||||
['moveaxis_98',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
|
||||
['mpinplace_99',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
|
||||
['mtl_5fconst_100',['MTL_CONST',['../defines_8h.html#a767ed9f2604de22b259cee02c4ce1d22',1,'defines.h']]],
|
||||
['mtl_5fdevice_101',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
|
||||
['mtl_5fresidency_5fset_102',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
|
||||
['mtlfclist_103',['MTLFCList',['../namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54',1,'mlx::core::metal']]],
|
||||
['mtx_104',['mtx',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a70410c9e612f871663929f1e8441a976',1,'mlx::core::scheduler::StreamThread']]],
|
||||
['multi_5fiter_105',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html',1,'pocketfft::detail::multi_iter< N >'],['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter::multi_iter()']]],
|
||||
['multiply_106',['Multiply',['../structmlx_1_1core_1_1detail_1_1_multiply.html',1,'mlx::core::detail::Multiply'],['../classmlx_1_1core_1_1_multiply.html',1,'mlx::core::Multiply'],['../struct_multiply.html',1,'Multiply'],['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply::Multiply()']]],
|
||||
['multiply_107',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
|
||||
['multivariate_5fnormal_108',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
|
||||
['maybeinsertbarrier_29',['maybeInsertBarrier',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991',1,'mlx::core::metal::CommandEncoder']]],
|
||||
['mb_5fblock_5fmerge_30',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
|
||||
['mb_5fblock_5fpartition_31',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
|
||||
['mb_5fblock_5fsort_32',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
|
||||
['mean_33',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['median3_34',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
|
||||
['merge_5fpartition_35',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
|
||||
['merge_5fstep_36',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
|
||||
['meshgrid_37',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
|
||||
['metal_38',['metal',['../namespacemetal.html',1,'']]],
|
||||
['metal_2eh_39',['metal.h',['../metal_8h.html',1,'']]],
|
||||
['metal_3a_3afast_40',['fast',['../namespacemetal_1_1fast.html',1,'metal']]],
|
||||
['metal_3a_3aprecise_41',['precise',['../namespacemetal_1_1precise.html',1,'metal']]],
|
||||
['metal_5fimpl_2eh_42',['metal_impl.h',['../metal__impl_8h.html',1,'']]],
|
||||
['metal_5fkernel_43',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
|
||||
['metalallocator_44',['MetalAllocator',['../classmlx_1_1core_1_1metal_1_1_metal_allocator.html',1,'mlx::core::metal']]],
|
||||
['metalkernelfunction_45',['MetalKernelFunction',['../namespacemlx_1_1core_1_1fast.html#a0e8c2c4ea7a946568c8fe5b4810417e0',1,'mlx::core::fast']]],
|
||||
['min_46',['Min',['../struct_min.html',1,'Min< U >'],['../classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef',1,'mlx::core::distributed::AllReduce::Min'],['../classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418',1,'mlx::core::Reduce::Min'],['../classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898',1,'mlx::core::Scan::Min'],['../classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf',1,'mlx::core::Scatter::Min']]],
|
||||
['min_47',['min',['../struct_limits.html#a6e81584ba65a4dc6ff9366b458e3a20e',1,'Limits::min'],['../struct_limits_3_01uint8__t_01_4.html#a408bd5a337e7292f06e63da81193629a',1,'Limits< uint8_t >::min'],['../struct_limits_3_01uint16__t_01_4.html#ae173984c3be8b6750f27daed581805fe',1,'Limits< uint16_t >::min'],['../struct_limits_3_01uint32__t_01_4.html#ab0c3975e02053b234c7b606ababa66e1',1,'Limits< uint32_t >::min'],['../struct_limits_3_01uint64__t_01_4.html#a80627f39e951398283942cefa48f4dd0',1,'Limits< uint64_t >::min'],['../struct_limits_3_01int8__t_01_4.html#a7a809307d2bba80382f0645d277eaa4b',1,'Limits< int8_t >::min'],['../struct_limits_3_01int16__t_01_4.html#adca7139647801e223c35b0abc7da5240',1,'Limits< int16_t >::min'],['../struct_limits_3_01int32__t_01_4.html#af336a1b22a8ed6a83a4cfb5bf8869771',1,'Limits< int32_t >::min'],['../struct_limits_3_01int64__t_01_4.html#a1c90fb96af515badaccaa835b08f7428',1,'Limits< int64_t >::min'],['../struct_limits_3_01half_01_4.html#aca7b036c257878bf1b80912fb5d4516d',1,'Limits< half >::min'],['../struct_limits_3_01float_01_4.html#a3225e334d372ee86128c89a440d8648f',1,'Limits< float >::min'],['../struct_limits_3_01bfloat16__t_01_4.html#a2fd1811b9f615b2b897904bc27d1cb49',1,'Limits< bfloat16_t >::min'],['../struct_limits_3_01bool_01_4.html#a139f787b57536d455490b8ef801d37cc',1,'Limits< bool >::min'],['../struct_limits_3_01complex64__t_01_4.html#aa67b04aa7abcd67f7af0808737ab8e14',1,'Limits< complex64_t >::min'],['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl< bfloat16_t >::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['min3_48',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
|
||||
['min_5fexponent_49',['min_exponent',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a13829f8c7a7c0efdc8946eff5d3c9470',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['min_5fexponent10_50',['min_exponent10',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aeaed172780720e06b8731cef3177e277',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['minimum_51',['Minimum',['../struct_minimum.html',1,'Minimum'],['../structmlx_1_1core_1_1detail_1_1_minimum.html',1,'mlx::core::detail::Minimum'],['../classmlx_1_1core_1_1_minimum.html',1,'mlx::core::Minimum'],['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum::Minimum()']]],
|
||||
['minimum_52',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
|
||||
['mlx_53',['mlx',['../namespacemlx.html',1,'']]],
|
||||
['mlx_2eh_54',['mlx.h',['../mlx_8h.html',1,'']]],
|
||||
['mlx_3a_3acore_55',['core',['../namespacemlx_1_1core.html',1,'mlx']]],
|
||||
['mlx_3a_3acore_3a_3aallocator_56',['allocator',['../namespacemlx_1_1core_1_1allocator.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adetail_57',['detail',['../namespacemlx_1_1core_1_1detail.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adistributed_58',['distributed',['../namespacemlx_1_1core_1_1distributed.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3adistributed_3a_3adetail_59',['detail',['../namespacemlx_1_1core_1_1distributed_1_1detail.html',1,'mlx::core::distributed']]],
|
||||
['mlx_3a_3acore_3a_3afast_60',['fast',['../namespacemlx_1_1core_1_1fast.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3afft_61',['fft',['../namespacemlx_1_1core_1_1fft.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3aio_62',['io',['../namespacemlx_1_1core_1_1io.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3alinalg_63',['linalg',['../namespacemlx_1_1core_1_1linalg.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3ametal_64',['metal',['../namespacemlx_1_1core_1_1metal.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3arandom_65',['random',['../namespacemlx_1_1core_1_1random.html',1,'mlx::core']]],
|
||||
['mlx_3a_3acore_3a_3ascheduler_66',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html',1,'mlx::core']]],
|
||||
['mlx_3a_3asteel_67',['steel',['../namespacemlx_1_1steel.html',1,'mlx']]],
|
||||
['mlx_5fatomic_68',['mlx_atomic',['../structmlx__atomic.html',1,'']]],
|
||||
['mlx_5fatomic_3c_20t_2c_20enable_5fif_5ft_3c_20is_5fmetal_5fatomic_3c_20t_20_3e_20_3e_20_3e_69',['mlx_atomic< T, enable_if_t< is_metal_atomic< T > > >',['../structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4.html',1,'']]],
|
||||
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_70',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset): atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread uint *expected, uint val, size_t offset): atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_71',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fand_5fexplicit_72',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_73',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_74',['mlx_atomic_fetch_max_explicit< float >',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_75',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_76',['mlx_atomic_fetch_min_explicit< float >',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_77',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5for_5fexplicit_78',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fload_5fexplicit_79',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fstore_5fexplicit_80',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
|
||||
['mlx_5flapack_5ffunc_81',['MLX_LAPACK_FUNC',['../lapack_8h.html#ae22db9704827bf013a0a61f21a47464b',1,'lapack.h']]],
|
||||
['mlx_5fmtl_5fconst_82',['MLX_MTL_CONST',['../kernels_2gemv__masked_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: gemv_masked.h'],['../quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: quantized.h'],['../sort_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST: sort.h']]],
|
||||
['mlx_5fmtl_5floop_5funroll_83',['MLX_MTL_LOOP_UNROLL',['../sort_8h.html#ad34b622323cebef136669fedd7229515',1,'sort.h']]],
|
||||
['mlx_5fmtl_5fpragma_5funroll_84',['MLX_MTL_PRAGMA_UNROLL',['../kernels_2gemv__masked_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL: gemv_masked.h'],['../backend_2metal_2kernels_2utils_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL: utils.h']]],
|
||||
['mlxconvparams_85',['MLXConvParams',['../struct_m_l_x_conv_params.html',1,'']]],
|
||||
['mlxconvparams_3c_202_20_3e_86',['MLXConvParams< 2 >',['../struct_m_l_x_conv_params.html',1,'']]],
|
||||
['mlxfastattentionparams_87',['MLXFastAttentionParams',['../struct_m_l_x_fast_attention_params.html',1,'']]],
|
||||
['mlxscaleddotproductattentionparams_88',['MLXScaledDotProductAttentionParams',['../struct_m_l_x_scaled_dot_product_attention_params.html',1,'']]],
|
||||
['mma_89',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
|
||||
['mma_2eh_90',['mma.h',['../mma_8h.html',1,'']]],
|
||||
['mma_5ft_91',['mma_t',['../structmlx_1_1steel_1_1_g_e_m_m_kernel.html#add8c6a31011a4895667c2a94a5af3782',1,'mlx::steel::GEMMKernel']]],
|
||||
['mmafrag_5facc_5ft_92',['MMAFrag_acc_t',['../structmlx_1_1steel_1_1_block_m_m_a.html#ae2c42cb6d0dde785859164c195f4d13c',1,'mlx::steel::BlockMMA']]],
|
||||
['mmafrag_5ft_93',['MMAFrag_t',['../structmlx_1_1steel_1_1_m_m_a_tile.html#abe33de70e34300745bad9aa822fd0382',1,'mlx::steel::MMATile']]],
|
||||
['mmatile_94',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel::MMATile< T, kTileRows_, kTileCols_, MMAFrag_ >'],['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile::MMATile()']]],
|
||||
['mmatile_3c_20float_2c_201_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_95',['MMATile< float, 1, TN, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['mmatile_3c_20float_2c_20tm_2c_201_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_96',['MMATile< float, TM, 1, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['mmatile_3c_20float_2c_20tm_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_97',['MMATile< float, TM, TN, mlx::steel::BaseMMAFrag >',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
|
||||
['move_5fshared_5fbuffer_98',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
|
||||
['moveaxis_99',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
|
||||
['mpinplace_100',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
|
||||
['mtl_5fconst_101',['MTL_CONST',['../defines_8h.html#a767ed9f2604de22b259cee02c4ce1d22',1,'defines.h']]],
|
||||
['mtl_5fdevice_102',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
|
||||
['mtl_5fresidency_5fset_103',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
|
||||
['mtlfclist_104',['MTLFCList',['../namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54',1,'mlx::core::metal']]],
|
||||
['mtx_105',['mtx',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a70410c9e612f871663929f1e8441a976',1,'mlx::core::scheduler::StreamThread']]],
|
||||
['multi_5fiter_106',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html',1,'pocketfft::detail::multi_iter< N >'],['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter::multi_iter()']]],
|
||||
['multiply_107',['Multiply',['../structmlx_1_1core_1_1detail_1_1_multiply.html',1,'mlx::core::detail::Multiply'],['../classmlx_1_1core_1_1_multiply.html',1,'mlx::core::Multiply'],['../struct_multiply.html',1,'Multiply'],['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply::Multiply()']]],
|
||||
['multiply_108',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
|
||||
['multivariate_5fnormal_109',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
|
||||
];
|
||||
|
2
docs/build/html/search/functions_1.js
vendored
2
docs/build/html/search/functions_1.js
vendored
@ -22,7 +22,7 @@ var searchData=
|
||||
['all_19',['all',['../group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68',1,'mlx::core::all(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga3689e12e8f42dadb4cbe2b07dc4099f4',1,'mlx::core::all(const array &a, StreamOrDevice s={})'],['../group__ops.html#gac0919c6ba53aea35a7683dea7e9a9a59',1,'mlx::core::all(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#gae2d5fcc5b62d673cca76c08b7b4afbbc',1,'mlx::core::all(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['all_5fgather_20',['all_gather',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04',1,'mlx::core::distributed::detail::all_gather()'],['../namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7',1,'mlx::core::distributed::all_gather()']]],
|
||||
['all_5freduce_21',['all_reduce',['../reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d',1,'reduce_all.h']]],
|
||||
['all_5freduce_5fdispatch_22',['all_reduce_dispatch',['../namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8',1,'mlx::core']]],
|
||||
['all_5freduce_5fdispatch_22',['all_reduce_dispatch',['../namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098',1,'mlx::core']]],
|
||||
['all_5fsum_23',['all_sum',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960',1,'mlx::core::distributed::detail::all_sum()'],['../namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980',1,'mlx::core::distributed::all_sum()']]],
|
||||
['allclose_24',['allclose',['../group__ops.html#gaf0cd4257de7542daf9faf5e605e31020',1,'mlx::core']]],
|
||||
['allgather_25',['AllGather',['../classmlx_1_1core_1_1distributed_1_1_all_gather.html#af4b10a5b61f160fb64353057c185b661',1,'mlx::core::distributed::AllGather']]],
|
||||
|
3
docs/build/html/search/functions_11.js
vendored
3
docs/build/html/search/functions_11.js
vendored
@ -22,5 +22,6 @@ var searchData=
|
||||
['quantizedmatmul_19',['QuantizedMatmul',['../classmlx_1_1core_1_1_quantized_matmul.html#a5bd164d038d9dc21919f7e0bfdeaa25c',1,'mlx::core::QuantizedMatmul']]],
|
||||
['quiet_5fnan_20',['quiet_NaN',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aebeb07c01984be246bc2d1b8f8e4ac7b',1,'metal::_numeric_limits_impl< bfloat16_t >']]],
|
||||
['qvm_21',['qvm',['../quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5',1,'quantized.h']]],
|
||||
['qvm_5fimpl_22',['qvm_impl',['../quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12',1,'quantized.h']]]
|
||||
['qvm_5fimpl_22',['qvm_impl',['../quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a',1,'quantized.h']]],
|
||||
['qvm_5fsplit_5fk_23',['qvm_split_k',['../quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8',1,'quantized.h']]]
|
||||
];
|
||||
|
2
docs/build/html/search/functions_13.js
vendored
2
docs/build/html/search/functions_13.js
vendored
@ -17,7 +17,7 @@ var searchData=
|
||||
['scatter_5fprod_14',['scatter_prod',['../group__ops.html#ga3708b5bcb61e2c63d213c4ce6ad0ffc0',1,'mlx::core::scatter_prod(const array &a, const std::vector< array > &indices, const array &updates, const std::vector< int > &axes, StreamOrDevice s={})'],['../group__ops.html#gaf83c53c453faa9083ba27e4b97539339',1,'mlx::core::scatter_prod(const array &a, const array &indices, const array &updates, int axis, StreamOrDevice s={})']]],
|
||||
['scheduler_15',['Scheduler',['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a3ae42aed78a2200e9d02776fcd2316ba',1,'mlx::core::scheduler::Scheduler::Scheduler()'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a61a74e3628899e66dde600e24a750648',1,'mlx::core::scheduler::Scheduler::Scheduler(const Scheduler &)=delete'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#ac3f77b7c93220dadd0b3bb2e903b7059',1,'mlx::core::scheduler::Scheduler::Scheduler(Scheduler &&)=delete']]],
|
||||
['scheduler_16',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html#ae856e468c2f7c8f8ec672522cc13730b',1,'mlx::core::scheduler']]],
|
||||
['sdpa_5fvector_17',['sdpa_vector',['../sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927',1,'sdpa_vector.h']]],
|
||||
['sdpa_5fvector_17',['sdpa_vector',['../sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae',1,'sdpa_vector.h']]],
|
||||
['seed_18',['seed',['../classmlx_1_1core_1_1random_1_1_key_sequence.html#a9f19c5da2031cba50d0ff996924347d8',1,'mlx::core::random::KeySequence::seed()'],['../namespacemlx_1_1core_1_1random.html#ac4ad325b613257306df74595d3d0e23b',1,'mlx::core::random::seed()']]],
|
||||
['seek_19',['seek',['../structmlx_1_1core_1_1_contiguous_iterator.html#a24719ee9e8667885d29c2ad74445520c',1,'mlx::core::ContiguousIterator::seek()'],['../classmlx_1_1core_1_1io_1_1_reader.html#acea55078bd39ccaa27a9a36f17a39cd1',1,'mlx::core::io::Reader::seek()'],['../classmlx_1_1core_1_1io_1_1_writer.html#a9c1716dda53aa36faea9c8fb1a3e34d4',1,'mlx::core::io::Writer::seek()'],['../classmlx_1_1core_1_1io_1_1_parallel_file_reader.html#a673c16b669f3cee13f387b7b0a1f39f7',1,'mlx::core::io::ParallelFileReader::seek()'],['../classmlx_1_1core_1_1io_1_1_file_writer.html#a9646f4ea048ae58719daeb588e2de433',1,'mlx::core::io::FileWriter::seek()']]],
|
||||
['select_20',['Select',['../classmlx_1_1core_1_1_select.html#a6f833fe55dd68ad3726bbf9a8f75eec9',1,'mlx::core::Select']]],
|
||||
|
164
docs/build/html/search/functions_3.js
vendored
164
docs/build/html/search/functions_3.js
vendored
@ -18,85 +18,87 @@ var searchData=
|
||||
['clip_15',['clip',['../group__ops.html#ga157cd7c23f9b306fee2e1eb2b9bf1dd8',1,'mlx::core']]],
|
||||
['cmplx_16',['cmplx',['../structpocketfft_1_1detail_1_1cmplx.html#a5b1ce506f1023f5254025ac81b831a2c',1,'pocketfft::detail::cmplx::cmplx()'],['../structpocketfft_1_1detail_1_1cmplx.html#a05491b4f1f22ca0bc49012f6a1c1710a',1,'pocketfft::detail::cmplx::cmplx(T r_, T i_)']]],
|
||||
['cndarr_17',['cndarr',['../classpocketfft_1_1detail_1_1cndarr.html#abf73f1b4ddcfb27d7f85cfa441607129',1,'pocketfft::detail::cndarr']]],
|
||||
['col_5freduce_5flooped_18',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
|
||||
['col_5freduce_5fsmall_19',['col_reduce_small',['../reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce',1,'reduce_col.h']]],
|
||||
['collapse_5fcontiguous_5fdims_20',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< size_t > > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector< array > &xs, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &&... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< int64_t > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< size_t > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &a, size_t size_cap=std::numeric_limits< int32_t >::max())']]],
|
||||
['commandencoder_21',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &)=delete']]],
|
||||
['commit_5fcommand_5fbuffer_22',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
|
||||
['communication_5fstream_23',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
|
||||
['compile_24',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile(const std::function< std::vector< array >(const std::vector< array > &)> &fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})']]],
|
||||
['compile_5favailable_5ffor_5fdevice_25',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
|
||||
['compile_5fclear_5fcache_26',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
|
||||
['compile_5ferase_27',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
|
||||
['compiled_28',['Compiled',['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled']]],
|
||||
['compiled_5fallocate_5foutputs_29',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
|
||||
['compiled_5fcheck_5fcontiguity_30',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
|
||||
['complex128_5ft_31',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex< double > v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
|
||||
['complex64_5ft_32',['complex64_t',['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex< float > v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
|
||||
['complex_5fmul_33',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
|
||||
['complex_5fmul_5fconj_34',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
|
||||
['compute_5fstrided_5findices_35',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
|
||||
['concatenate_36',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate']]],
|
||||
['concatenate_37',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector< array > &arrays, StreamOrDevice s={})']]],
|
||||
['concatenate_5fgpu_38',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
|
||||
['concurrentcontext_39',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext']]],
|
||||
['conj_40',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
|
||||
['conjugate_41',['Conjugate',['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate']]],
|
||||
['conjugate_42',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
|
||||
['contiguous_5fscan_43',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
|
||||
['contiguousiterator_44',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector< int > &shape, const std::vector< StrideT > &strides, int dims)']]],
|
||||
['conv_45',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
|
||||
['conv1d_46',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
|
||||
['conv2d_47',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
|
||||
['conv2dinputblockloadergeneral_48',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral']]],
|
||||
['conv2dinputblockloaderlargefilter_49',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter']]],
|
||||
['conv2dinputblockloadersmallchannels_50',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels']]],
|
||||
['conv2dinputblockloadersmallfilter_51',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter']]],
|
||||
['conv2dweightblockloader_52',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader']]],
|
||||
['conv2dweightblockloadergeneral_53',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral']]],
|
||||
['conv2dweightblockloadersmallchannels_54',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels']]],
|
||||
['conv3d_55',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
|
||||
['conv_5fgeneral_56',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
|
||||
['conv_5ftranspose1d_57',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
|
||||
['conv_5ftranspose2d_58',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
|
||||
['conv_5ftranspose3d_59',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
|
||||
['convolution_60',['Convolution',['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution']]],
|
||||
['copy_61',['Copy',['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy']]],
|
||||
['copy_62',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
|
||||
['copy_5fg_63',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
|
||||
['copy_5fg_5fnd1_64',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
|
||||
['copy_5fg_5fnd2_65',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
|
||||
['copy_5fg_5fnd3_66',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
|
||||
['copy_5fgg_67',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd1_68',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd2_69',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd3_70',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
|
||||
['copy_5fgpu_71',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype)']]],
|
||||
['copy_5fgpu_5finplace_72',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int64_t > &istride, int64_t ioffset, CopyType ctype, const Stream &s)']]],
|
||||
['copy_5fhartley_73',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5finplace_74',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &src, array &dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
|
||||
['copy_5finput_75',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< cmplx< T > > &src, cmplx< vtype_t< T > > *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, vtype_t< T > *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, T *dst)']]],
|
||||
['copy_5foutput_76',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const cmplx< vtype_t< T > > *src, ndarr< cmplx< T > > &dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5fs_77',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
|
||||
['copy_5fs2_78',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
|
||||
['copy_5fshared_5fbuffer_79',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &other)']]],
|
||||
['copy_5fv_80',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
|
||||
['copy_5fv2_81',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
|
||||
['cos_82',['Cos',['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos']]],
|
||||
['cos_83',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
|
||||
['cosh_84',['Cosh',['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh']]],
|
||||
['cosh_85',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
|
||||
['cospi_86',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
|
||||
['cost_5fguess_87',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
|
||||
['count_5fdown_88',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
|
||||
['cross_89',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
|
||||
['cummax_90',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
|
||||
['cummin_91',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
|
||||
['cumprod_92',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
|
||||
['cumsum_93',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
|
||||
['custom_94',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom']]],
|
||||
['custom_5ffunction_95',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
|
||||
['custom_5fvjp_96',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
|
||||
['customkernel_97',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel']]],
|
||||
['customtransforms_98',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms']]]
|
||||
['col_5freduce_5f2pass_18',['col_reduce_2pass',['../reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d',1,'reduce_col.h']]],
|
||||
['col_5freduce_5flongcolumn_19',['col_reduce_longcolumn',['../reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb',1,'reduce_col.h']]],
|
||||
['col_5freduce_5flooped_20',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
|
||||
['col_5freduce_5fsmall_21',['col_reduce_small',['../reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5',1,'reduce_col.h']]],
|
||||
['collapse_5fcontiguous_5fdims_22',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< size_t > > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector< array > &xs, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &&... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< int64_t > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< size_t > &strides, size_t size_cap=std::numeric_limits< int32_t >::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &a, size_t size_cap=std::numeric_limits< int32_t >::max())']]],
|
||||
['commandencoder_23',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &)=delete']]],
|
||||
['commit_5fcommand_5fbuffer_24',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
|
||||
['communication_5fstream_25',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
|
||||
['compile_26',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile(const std::function< std::vector< array >(const std::vector< array > &)> &fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})']]],
|
||||
['compile_5favailable_5ffor_5fdevice_27',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
|
||||
['compile_5fclear_5fcache_28',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
|
||||
['compile_5ferase_29',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
|
||||
['compiled_30',['Compiled',['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled']]],
|
||||
['compiled_5fallocate_5foutputs_31',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
|
||||
['compiled_5fcheck_5fcontiguity_32',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
|
||||
['complex128_5ft_33',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex< double > v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
|
||||
['complex64_5ft_34',['complex64_t',['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex< float > v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
|
||||
['complex_5fmul_35',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
|
||||
['complex_5fmul_5fconj_36',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
|
||||
['compute_5fstrided_5findices_37',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
|
||||
['concatenate_38',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate']]],
|
||||
['concatenate_39',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector< array > &arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector< array > &arrays, StreamOrDevice s={})']]],
|
||||
['concatenate_5fgpu_40',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
|
||||
['concurrentcontext_41',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext']]],
|
||||
['conj_42',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
|
||||
['conjugate_43',['Conjugate',['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate']]],
|
||||
['conjugate_44',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
|
||||
['contiguous_5fscan_45',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
|
||||
['contiguousiterator_46',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector< int > &shape, const std::vector< StrideT > &strides, int dims)']]],
|
||||
['conv_47',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
|
||||
['conv1d_48',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
|
||||
['conv2d_49',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
|
||||
['conv2dinputblockloadergeneral_50',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral']]],
|
||||
['conv2dinputblockloaderlargefilter_51',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter']]],
|
||||
['conv2dinputblockloadersmallchannels_52',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels']]],
|
||||
['conv2dinputblockloadersmallfilter_53',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter']]],
|
||||
['conv2dweightblockloader_54',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader']]],
|
||||
['conv2dweightblockloadergeneral_55',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral']]],
|
||||
['conv2dweightblockloadersmallchannels_56',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels']]],
|
||||
['conv3d_57',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
|
||||
['conv_5fgeneral_58',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector< int > stride={}, std::vector< int > padding_lo={}, std::vector< int > padding_hi={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &input, const array &weight, std::vector< int > stride={}, std::vector< int > padding={}, std::vector< int > kernel_dilation={}, std::vector< int > input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
|
||||
['conv_5ftranspose1d_59',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
|
||||
['conv_5ftranspose2d_60',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
|
||||
['conv_5ftranspose3d_61',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
|
||||
['convolution_62',['Convolution',['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution']]],
|
||||
['copy_63',['Copy',['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy']]],
|
||||
['copy_64',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
|
||||
['copy_5fg_65',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
|
||||
['copy_5fg_5fnd1_66',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
|
||||
['copy_5fg_5fnd2_67',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
|
||||
['copy_5fg_5fnd3_68',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
|
||||
['copy_5fgg_69',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd1_70',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd2_71',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
|
||||
['copy_5fgg_5fnd3_72',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
|
||||
['copy_5fgpu_73',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &src, array &out, CopyType ctype)']]],
|
||||
['copy_5fgpu_5finplace_74',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &src, array &out, CopyType ctype, const Stream &s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &in, array &out, const std::vector< int64_t > &istride, int64_t ioffset, CopyType ctype, const Stream &s)']]],
|
||||
['copy_5fhartley_75',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5finplace_76',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &src, array &dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &src, array &dst, const std::vector< int > &data_shape, const std::vector< stride_t > &i_strides, const std::vector< stride_t > &o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
|
||||
['copy_5finput_77',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< cmplx< T > > &src, cmplx< vtype_t< T > > *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, vtype_t< T > *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter< vlen > &it, const cndarr< T > &src, T *dst)']]],
|
||||
['copy_5foutput_78',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const cmplx< vtype_t< T > > *src, ndarr< cmplx< T > > &dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const vtype_t< T > *src, ndarr< T > &dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter< vlen > &it, const T *src, ndarr< T > &dst)']]],
|
||||
['copy_5fs_79',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
|
||||
['copy_5fs2_80',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
|
||||
['copy_5fshared_5fbuffer_81',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &other)']]],
|
||||
['copy_5fv_82',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
|
||||
['copy_5fv2_83',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
|
||||
['cos_84',['Cos',['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos']]],
|
||||
['cos_85',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
|
||||
['cosh_86',['Cosh',['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh']]],
|
||||
['cosh_87',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
|
||||
['cospi_88',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
|
||||
['cost_5fguess_89',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
|
||||
['count_5fdown_90',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
|
||||
['cross_91',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
|
||||
['cummax_92',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
|
||||
['cummin_93',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
|
||||
['cumprod_94',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
|
||||
['cumsum_95',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
|
||||
['custom_96',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom']]],
|
||||
['custom_5ffunction_97',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
|
||||
['custom_5fvjp_98',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
|
||||
['customkernel_99',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel']]],
|
||||
['customtransforms_100',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms']]]
|
||||
];
|
||||
|
2
docs/build/html/search/functions_7.js
vendored
2
docs/build/html/search/functions_7.js
vendored
@ -44,7 +44,7 @@ var searchData=
|
||||
['get_5fpool_41',['get_pool',['../namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a',1,'pocketfft::detail::threading']]],
|
||||
['get_5fprimitive_5fstring_42',['get_primitive_string',['../namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60',1,'mlx::core']]],
|
||||
['get_5fquantized_5fkernel_43',['get_quantized_kernel',['../namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e',1,'mlx::core']]],
|
||||
['get_5freduce_5finit_5fkernel_44',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89',1,'mlx::core']]],
|
||||
['get_5freduce_5finit_5fkernel_44',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647',1,'mlx::core']]],
|
||||
['get_5freduce_5fkernel_45',['get_reduce_kernel',['../namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b',1,'mlx::core']]],
|
||||
['get_5freduction_5fplan_46',['get_reduction_plan',['../namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba',1,'mlx::core']]],
|
||||
['get_5fscan_5fkernel_47',['get_scan_kernel',['../namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f',1,'mlx::core']]],
|
||||
|
71
docs/build/html/search/functions_d.js
vendored
71
docs/build/html/search/functions_d.js
vendored
@ -14,39 +14,40 @@ var searchData=
|
||||
['max3_11',['max3',['../namespacemetal.html#a00f9c0ad66d969794614f56912eed9c9',1,'metal::max3()'],['../namespacemetal_1_1fast.html#a6fc2cf18ffa8149561864c86dba0f803',1,'metal::fast::max3()'],['../namespacemetal_1_1precise.html#ac490e8614ebd2c9343af1ae6c0d4e82c',1,'metal::precise::max3()']]],
|
||||
['maximum_12',['Maximum',['../classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816',1,'mlx::core::Maximum']]],
|
||||
['maximum_13',['maximum',['../group__ops.html#ga7ade2ea305e2e4219c3609443fb5db8d',1,'mlx::core']]],
|
||||
['mb_5fblock_5fmerge_14',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
|
||||
['mb_5fblock_5fpartition_15',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
|
||||
['mb_5fblock_5fsort_16',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
|
||||
['mean_17',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['median3_18',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
|
||||
['merge_5fpartition_19',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
|
||||
['merge_5fstep_20',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
|
||||
['meshgrid_21',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
|
||||
['metal_5fkernel_22',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
|
||||
['min_23',['min',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl< bfloat16_t >::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['min3_24',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
|
||||
['minimum_25',['Minimum',['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum']]],
|
||||
['minimum_26',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
|
||||
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_27',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset): atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread uint *expected, uint val, size_t offset): atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_28',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fand_5fexplicit_29',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_30',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_31',['mlx_atomic_fetch_max_explicit< float >',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_32',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_33',['mlx_atomic_fetch_min_explicit< float >',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_34',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5for_5fexplicit_35',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fload_5fexplicit_36',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fstore_5fexplicit_37',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
|
||||
['mma_38',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
|
||||
['mmatile_39',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile']]],
|
||||
['move_5fshared_5fbuffer_40',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
|
||||
['moveaxis_41',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
|
||||
['mpinplace_42',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
|
||||
['mtl_5fdevice_43',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
|
||||
['mtl_5fresidency_5fset_44',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
|
||||
['multi_5fiter_45',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter']]],
|
||||
['multiply_46',['Multiply',['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply']]],
|
||||
['multiply_47',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
|
||||
['multivariate_5fnormal_48',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
|
||||
['maybeinsertbarrier_14',['maybeInsertBarrier',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991',1,'mlx::core::metal::CommandEncoder']]],
|
||||
['mb_5fblock_5fmerge_15',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
|
||||
['mb_5fblock_5fpartition_16',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
|
||||
['mb_5fblock_5fsort_17',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
|
||||
['mean_18',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['median3_19',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
|
||||
['merge_5fpartition_20',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
|
||||
['merge_5fstep_21',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
|
||||
['meshgrid_22',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
|
||||
['metal_5fkernel_23',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
|
||||
['min_24',['min',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl< bfloat16_t >::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &a, const std::vector< int > &axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
|
||||
['min3_25',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
|
||||
['minimum_26',['Minimum',['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum']]],
|
||||
['minimum_27',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
|
||||
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_28',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset): atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread uint *expected, uint val, size_t offset): atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_29',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fand_5fexplicit_30',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_31',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_32',['mlx_atomic_fetch_max_explicit< float >',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_33',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_34',['mlx_atomic_fetch_min_explicit< float >',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_35',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5ffetch_5for_5fexplicit_36',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fload_5fexplicit_37',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
|
||||
['mlx_5fatomic_5fstore_5fexplicit_38',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
|
||||
['mma_39',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread frag_type &D, thread frag_type &A, thread frag_type &B, thread frag_type &C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag< T, 8, 8 >::mma(thread mat_type &D, thread mat_type &A, thread mat_type &B, thread mat_type &C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
|
||||
['mmatile_40',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile']]],
|
||||
['move_5fshared_5fbuffer_41',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector< size_t > &strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
|
||||
['moveaxis_42',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
|
||||
['mpinplace_43',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
|
||||
['mtl_5fdevice_44',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
|
||||
['mtl_5fresidency_5fset_45',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
|
||||
['multi_5fiter_46',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter']]],
|
||||
['multiply_47',['Multiply',['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply']]],
|
||||
['multiply_48',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
|
||||
['multivariate_5fnormal_49',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
|
||||
];
|
||||
|
2
docs/build/html/searchindex.js
vendored
2
docs/build/html/searchindex.js
vendored
File diff suppressed because one or more lines are too long
@ -99,13 +99,14 @@ $(function(){ initResizable(false); });
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">dispatchThreadgroups</a>(MTL::Size grid_dims, MTL::Size group_dims)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a>(MTL::Size grid_dims, MTL::Size group_dims)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">operator-></a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">operator=</a>(const CommandEncoder &)=delete</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">set_input_array</a>(const array &a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(array &a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">operator-></a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">operator=</a>(const CommandEncoder &)=delete</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">set_input_array</a>(const array &a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(array &a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
|
||||
</table></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
|
@ -121,6 +121,8 @@ Public Member Functions</h2></td></tr>
|
||||
<tr class="separator:a74bcd8e35f80f5a62db48c4a2bb0173e"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1e41477f2f489e38499f7830a91c9810" id="r_a1e41477f2f489e38499f7830a91c9810"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a> (MTL::Size grid_dims, MTL::Size group_dims)</td></tr>
|
||||
<tr class="separator:a1e41477f2f489e38499f7830a91c9810"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ad538ae88f90560063f9ba502e2795991" id="r_ad538ae88f90560063f9ba502e2795991"><td class="memItemLeft" align="right" valign="top">void </td><td class="memItemRight" valign="bottom"><a class="el" href="#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a> ()</td></tr>
|
||||
<tr class="separator:ad538ae88f90560063f9ba502e2795991"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a48b548a0b15f9d1279c938a1c6167034" id="r_a48b548a0b15f9d1279c938a1c6167034"><td class="memItemLeft" align="right" valign="top"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a> ()</td></tr>
|
||||
<tr class="separator:a48b548a0b15f9d1279c938a1c6167034"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a9b6dd221ccd2d939d544004cb6279198" id="r_a9b6dd221ccd2d939d544004cb6279198"><td class="memItemLeft" align="right" valign="top"> </td><td class="memItemRight" valign="bottom"><a class="el" href="#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a> ()</td></tr>
|
||||
@ -256,6 +258,23 @@ Public Member Functions</h2></td></tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ad538ae88f90560063f9ba502e2795991" name="ad538ae88f90560063f9ba502e2795991"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ad538ae88f90560063f9ba502e2795991">◆ </a></span>maybeInsertBarrier()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void mlx::core::metal::CommandEncoder::maybeInsertBarrier </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"><span class="paramname"><em></em></span></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aac45ab0630ea32cf7d15c7ba3e229966" name="aac45ab0630ea32cf7d15c7ba3e229966"></a>
|
||||
|
@ -986,13 +986,13 @@ We will prioritize including it.</p>
|
||||
<span class="n">ys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">4096</span><span class="p">))</span>
|
||||
|
||||
<span class="k">def</span> <span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])]</span>
|
||||
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Instead you can use <a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html#mlx.core.vmap" title="mlx.core.vmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">vmap()</span></code></a> to automatically vectorize the addition:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Vectorize over the second dimension of x and the</span>
|
||||
<span class="c1"># first dimension of y</span>
|
||||
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
|
||||
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>The <code class="docutils literal notranslate"><span class="pre">in_axes</span></code> parameter can be used to specify which dimensions of the
|
||||
|
2
docs/build/html/usage/indexing.html
vendored
2
docs/build/html/usage/indexing.html
vendored
@ -922,7 +922,7 @@ undefined behavior.</p></li>
|
||||
from the GPU. Performing bounds checking for array indices before launching the
|
||||
kernel would be extremely inefficient.</p>
|
||||
<p>Indexing with boolean masks is something that MLX may support in the future. In
|
||||
general, MLX has limited support for operations for which outputs
|
||||
general, MLX has limited support for operations for which output
|
||||
<em>shapes</em> are dependent on input <em>data</em>. Other examples of these types of
|
||||
operations which MLX does not yet support include <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html#numpy.nonzero" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.nonzero()</span></code></a> and the
|
||||
single input version of <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.where.html#numpy.where" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.where()</span></code></a>.</p>
|
||||
|
2
docs/build/html/usage/lazy_evaluation.html
vendored
2
docs/build/html/usage/lazy_evaluation.html
vendored
@ -952,7 +952,7 @@ stochastic gradient descent). A natural and usually efficient place to use
|
||||
</div>
|
||||
<p>An important behavior to be aware of is when the graph will be implicitly
|
||||
evaluated. Anytime you <code class="docutils literal notranslate"><span class="pre">print</span></code> an array, convert it to an
|
||||
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access it’s memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
|
||||
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
|
||||
the graph will be evaluated. Saving arrays via <a class="reference internal" href="../python/_autosummary/mlx.core.save.html#mlx.core.save" title="mlx.core.save"><code class="xref py py-func docutils literal notranslate"><span class="pre">save()</span></code></a> (or any other MLX
|
||||
saving functions) will also evaluate the array.</p>
|
||||
<p>Calling <a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html#mlx.core.array.item" title="mlx.core.array.item"><code class="xref py py-func docutils literal notranslate"><span class="pre">array.item()</span></code></a> on a scalar array will also evaluate it. In the
|
||||
|
Loading…
Reference in New Issue
Block a user