This commit is contained in:
CircleCI Docs
2024-11-05 19:54:16 +00:00
parent 3addf172d9
commit 98e590e52d
51 changed files with 2277 additions and 1802 deletions

View File

@@ -534,8 +534,8 @@ Functions</h2></td></tr>
<tr class="separator:a84ebe6275218070f0ea320f126f64e22"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:afb57825bb763050cc9a9d194aa41ac36" id="r_afb57825bb763050cc9a9d194aa41ac36"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#afb57825bb763050cc9a9d194aa41ac36">get_mb_sort_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;idx, int bn, int tn)</td></tr>
<tr class="separator:afb57825bb763050cc9a9d194aa41ac36"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a51c4bb09230348bd0252e22bfdc9bc89" id="r_a51c4bb09230348bd0252e22bfdc9bc89"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a51c4bb09230348bd0252e22bfdc9bc89">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a51c4bb09230348bd0252e22bfdc9bc89"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3bd386cb6db09f636963ce66ceaf8647" id="r_a3bd386cb6db09f636963ce66ceaf8647"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a3bd386cb6db09f636963ce66ceaf8647">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a3bd386cb6db09f636963ce66ceaf8647"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a7aa91fcfe8b9caa42d60a957f11bfe6b" id="r_a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, int ndim=-1, int bm=-1, int bn=-1)</td></tr>
<tr class="separator:a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a84fa8e0aee321a9d614433a0b933103b" id="r_a84fa8e0aee321a9d614433a0b933103b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &amp;func_consts, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</td></tr>
@@ -563,8 +563,8 @@ Functions</h2></td></tr>
<tr class="separator:a227588758ccc9ee869dba147e830bb74"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab43a7633794498e1c6775cca829eb886" id="r_ab43a7633794498e1c6775cca829eb886"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#ab43a7633794498e1c6775cca829eb886">steel_matmul</a> (const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;a, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;b, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;copies, std::vector&lt; int &gt; batch_shape={}, std::vector&lt; size_t &gt; A_batch_stride={}, std::vector&lt; size_t &gt; B_batch_stride={})</td></tr>
<tr class="separator:ab43a7633794498e1c6775cca829eb886"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:af7b7ca7c6aa87558d9f98cee5c7a99a8" id="r_af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s, std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;copies)</td></tr>
<tr class="separator:af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3ab0fd997d9a35782106ff083a72e098" id="r_a3ab0fd997d9a35782106ff083a72e098"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a3ab0fd997d9a35782106ff083a72e098">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:a3ab0fd997d9a35782106ff083a72e098"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab1eeca8ec6fa31819ee108fa6ed2c41b" id="r_ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:aa0332c64ee9965f05026c30a0b778000" id="r_aa0332c64ee9965f05026c30a0b778000"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
@@ -2634,8 +2634,8 @@ template&lt;typename... T&gt; </div>
</div>
</div>
<h2 class="groupheader">Function Documentation</h2>
<a id="af7b7ca7c6aa87558d9f98cee5c7a99a8" name="af7b7ca7c6aa87558d9f98cee5c7a99a8"></a>
<h2 class="memtitle"><span class="permalink"><a href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">&#9670;&#160;</a></span>all_reduce_dispatch()</h2>
<a id="a3ab0fd997d9a35782106ff083a72e098" name="a3ab0fd997d9a35782106ff083a72e098"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a3ab0fd997d9a35782106ff083a72e098">&#9670;&#160;</a></span>all_reduce_dispatch()</h2>
<div class="memitem">
<div class="memproto">
@@ -2668,12 +2668,7 @@ template&lt;typename... T&gt; </div>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;</td> <td class="paramname"><span class="paramname"><em>s</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;</td> <td class="paramname"><span class="paramname"><em>copies</em></span>&#160;)</td>
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;</td> <td class="paramname"><span class="paramname"><em>s</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
@@ -4418,8 +4413,8 @@ template&lt;typename... Arrays, typename = enable_for_arrays_t&lt;Arrays...&gt;
</div>
</div>
<a id="a51c4bb09230348bd0252e22bfdc9bc89" name="a51c4bb09230348bd0252e22bfdc9bc89"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a51c4bb09230348bd0252e22bfdc9bc89">&#9670;&#160;</a></span>get_reduce_init_kernel()</h2>
<a id="a3bd386cb6db09f636963ce66ceaf8647" name="a3bd386cb6db09f636963ce66ceaf8647"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a3bd386cb6db09f636963ce66ceaf8647">&#9670;&#160;</a></span>get_reduce_init_kernel()</h2>
<div class="memitem">
<div class="memproto">
@@ -4434,6 +4429,16 @@ template&lt;typename... Arrays, typename = enable_for_arrays_t&lt;Arrays...&gt;
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>kernel_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>func_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>op_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>