mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 03:58:12 +08:00
docs update
This commit is contained in:
committed by
CircleCI Docs
parent
44a49282c9
commit
3b72c27899
231
docs/build/html/reduce__col_8h.html
vendored
231
docs/build/html/reduce__col_8h.html
vendored
@@ -85,156 +85,25 @@ $(function() {
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a7da96ffb66e1fda27bad2852c2285b94" id="r_a7da96ffb66e1fda27bad2852c2285b94"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op > </td></tr>
|
||||
<tr class="memitem:a7da96ffb66e1fda27bad2852c2285b94"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a7da96ffb66e1fda27bad2852c2285b94">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant size_t &non_col_reductions, const constant int *non_col_shapes, const constant size_t *non_col_strides, const constant int &non_col_ndim, uint tid)</td></tr>
|
||||
<tr class="separator:a7da96ffb66e1fda27bad2852c2285b94"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a6d72f6a88a37d8e031d0ac33f26ecbb4" id="r_a6d72f6a88a37d8e031d0ac33f26ecbb4"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </td></tr>
|
||||
<tr class="memitem:a6d72f6a88a37d8e031d0ac33f26ecbb4"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC U </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a6d72f6a88a37d8e031d0ac33f26ecbb4">_contiguous_strided_reduce</a> (const device T *in, threadgroup U *local_data, uint in_idx, uint reduction_size, uint reduction_stride, uint2 tid, uint2 lid, uint2 lsize)</td></tr>
|
||||
<tr class="separator:a6d72f6a88a37d8e031d0ac33f26ecbb4"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3b14333e195c0a07b70069bebf85d5c3" id="r_a3b14333e195c0a07b70069bebf85d5c3"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </td></tr>
|
||||
<tr class="memitem:a3b14333e195c0a07b70069bebf85d5c3"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a3b14333e195c0a07b70069bebf85d5c3">col_reduce_general</a> (const device T *in, device <a class="el" href="structmlx__atomic.html">mlx_atomic</a>< U > *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 lsize)</td></tr>
|
||||
<tr class="separator:a3b14333e195c0a07b70069bebf85d5c3"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aef03a4131c710ba8f94a666f58719eb7" id="r_aef03a4131c710ba8f94a666f58719eb7"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </td></tr>
|
||||
<tr class="memitem:aef03a4131c710ba8f94a666f58719eb7"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aef03a4131c710ba8f94a666f58719eb7">col_reduce_general_no_atomics</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant size_t &out_size, const constant int *shape, const constant size_t *strides, const constant int &ndim, threadgroup U *local_data, uint3 tid, uint3 lid, uint3 gid, uint3 lsize, uint3 gsize)</td></tr>
|
||||
<tr class="separator:aef03a4131c710ba8f94a666f58719eb7"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce" id="r_adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS = 0, int N_READS = REDUCE_N_READS> </td></tr>
|
||||
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)</td></tr>
|
||||
<tr class="separator:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385" id="r_a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplParams" colspan="2">template<typename T , typename U , typename Op , int NDIMS = 0, int BM = 8, int BN = 128> </td></tr>
|
||||
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a> (const device T *in, device U *out, const constant size_t &reduction_size, const constant size_t &reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &reduce_ndim, const constant size_t &non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</td></tr>
|
||||
<tr class="memdesc:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="mdescLeft"> </td><td class="mdescRight">Our approach is the following simple looped approach: <br /></td></tr>
|
||||
<tr class="separator:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="a6d72f6a88a37d8e031d0ac33f26ecbb4" name="a6d72f6a88a37d8e031d0ac33f26ecbb4"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a6d72f6a88a37d8e031d0ac33f26ecbb4">◆ </a></span>_contiguous_strided_reduce()</h2>
|
||||
<a id="a11bfc6112ae2386ac03f5ea7b7d93385" name="a11bfc6112ae2386ac03f5ea7b7d93385"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a11bfc6112ae2386ac03f5ea7b7d93385">◆ </a></span>col_reduce_looped()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </div>
|
||||
template<typename T , typename U , typename Op , int NDIMS = 0, int BM = 8, int BN = 128> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">METAL_FUNC U _contiguous_strided_reduce </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">threadgroup U *</td> <td class="paramname"><span class="paramname"><em>local_data</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>in_idx</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>reduction_size</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint2</td> <td class="paramname"><span class="paramname"><em>tid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint2</td> <td class="paramname"><span class="paramname"><em>lid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint2</td> <td class="paramname"><span class="paramname"><em>lsize</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a3b14333e195c0a07b70069bebf85d5c3" name="a3b14333e195c0a07b70069bebf85d5c3"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3b14333e195c0a07b70069bebf85d5c3">◆ </a></span>col_reduce_general()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_general </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">device <a class="el" href="structmlx__atomic.html">mlx_atomic</a>< U > *</td> <td class="paramname"><span class="paramname"><em>out</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_size</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>out_size</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>shape</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>strides</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>ndim</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">threadgroup U *</td> <td class="paramname"><span class="paramname"><em>local_data</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aef03a4131c710ba8f94a666f58719eb7" name="aef03a4131c710ba8f94a666f58719eb7"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aef03a4131c710ba8f94a666f58719eb7">◆ </a></span>col_reduce_general_no_atomics()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_general_no_atomics </td>
|
||||
<td class="memname">void col_reduce_looped </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em>, </span></td>
|
||||
</tr>
|
||||
@@ -253,11 +122,6 @@ template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>out_size</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -276,17 +140,22 @@ template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">threadgroup U *</td> <td class="paramname"><span class="paramname"><em>local_data</em>, </span></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em>, </span></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em>, </span></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>non_col_reductions</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
@@ -296,25 +165,39 @@ template<typename T , typename U , typename Op , int N_READS = REDUCE_N_READS
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em>, </span></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em></span> )</td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
<p>Our approach is the following simple looped approach: </p>
|
||||
<ol type="1">
|
||||
<li>Each thread keeps running totals for BN / n_simdgroups outputs.</li>
|
||||
<li>Load a tile BM, BN in registers and accumulate in the running totals</li>
|
||||
<li>Move ahead by BM steps until the column axis and the non column reductions are exhausted.</li>
|
||||
<li>If BM == 32 then transpose in SM and simd reduce the running totals. Otherwise write in shared memory and BN threads accumulate the running totals with a loop.</li>
|
||||
<li>Write them to the output </li>
|
||||
</ol>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a7da96ffb66e1fda27bad2852c2285b94" name="a7da96ffb66e1fda27bad2852c2285b94"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a7da96ffb66e1fda27bad2852c2285b94">◆ </a></span>col_reduce_small()</h2>
|
||||
<a id="adf7aeb18cd1d5042cf6d9b46b582d8ce" name="adf7aeb18cd1d5042cf6d9b46b582d8ce"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">◆ </a></span>col_reduce_small()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , typename U , typename Op > </div>
|
||||
template<typename T , typename U , typename Op , int NDIMS = 0, int N_READS = REDUCE_N_READS> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void col_reduce_small </td>
|
||||
@@ -336,11 +219,6 @@ template<typename T , typename U , typename Op > </div>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>out_size</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -356,6 +234,21 @@ template<typename T , typename U , typename Op > </div>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>ndim</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -364,22 +257,32 @@ template<typename T , typename U , typename Op > </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>non_col_shapes</em>, </span></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>non_col_strides</em>, </span></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>non_col_ndim</em>, </span></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>tid</em></span> )</td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em>, </span></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tsize</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
Reference in New Issue
Block a user