mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
rebase
This commit is contained in:
54
docs/build/html/steel__attention_8h.html
vendored
54
docs/build/html/steel__attention_8h.html
vendored
@@ -131,9 +131,9 @@ Classes</h2></td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a5423b2a414f5e3c14166d568dedfbd33" id="r_a5423b2a414f5e3c14166d568dedfbd33"><td class="memTemplParams" colspan="2">template<typename T, int BQ, int BK, int BD, int WM, int WN, typename AccumType = float> </td></tr>
|
||||
<tr class="memitem:a5423b2a414f5e3c14166d568dedfbd33"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a5423b2a414f5e3c14166d568dedfbd33">attention</a> (const device T *Q, const device T *K, const device T *V, device T *O, const constant <a class="el" href="structmlx_1_1steel_1_1_attn_params.html">AttnParams</a> *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)</td></tr>
|
||||
<tr class="separator:a5423b2a414f5e3c14166d568dedfbd33"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a835f90451dd42ce2d52352d5cce0722a" id="r_a835f90451dd42ce2d52352d5cce0722a"><td class="memTemplParams" colspan="2">template<typename T, int BQ, int BK, int BD, int WM, int WN, typename MaskType = float, typename AccumType = float> </td></tr>
|
||||
<tr class="memitem:a835f90451dd42ce2d52352d5cce0722a"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a835f90451dd42ce2d52352d5cce0722a">attention</a> (const device T *Q, const device T *K, const device T *V, device T *O, const constant <a class="el" href="structmlx_1_1steel_1_1_attn_params.html">AttnParams</a> *params, const constant <a class="el" href="structmlx_1_1steel_1_1_attn_mask_params.html">AttnMaskParams</a> *mask_params, const device MaskType *mask, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)</td></tr>
|
||||
<tr class="separator:a835f90451dd42ce2d52352d5cce0722a"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="var-members" name="var-members"></a>
|
||||
Variables</h2></td></tr>
|
||||
@@ -141,15 +141,19 @@ Variables</h2></td></tr>
|
||||
<tr class="separator:a171fdea1b23976453f5dc5e6b3161982"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a8bdd2cecf97aa5b033152b1d0f0d2416" id="r_a8bdd2cecf97aa5b033152b1d0f0d2416"><td class="memItemLeft" align="right" valign="top">constant bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#a8bdd2cecf97aa5b033152b1d0f0d2416">align_K</a></td></tr>
|
||||
<tr class="separator:a8bdd2cecf97aa5b033152b1d0f0d2416"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a6ed0dd113fe7d471fc0b869b8c028c81" id="r_a6ed0dd113fe7d471fc0b869b8c028c81"><td class="memItemLeft" align="right" valign="top">constant bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#a6ed0dd113fe7d471fc0b869b8c028c81">has_mask</a></td></tr>
|
||||
<tr class="separator:a6ed0dd113fe7d471fc0b869b8c028c81"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:abfa50278ba59a90e0acb7e5d94500741" id="r_abfa50278ba59a90e0acb7e5d94500741"><td class="memItemLeft" align="right" valign="top">constant bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#abfa50278ba59a90e0acb7e5d94500741">do_causal</a></td></tr>
|
||||
<tr class="separator:abfa50278ba59a90e0acb7e5d94500741"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="a5423b2a414f5e3c14166d568dedfbd33" name="a5423b2a414f5e3c14166d568dedfbd33"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a5423b2a414f5e3c14166d568dedfbd33">◆ </a></span>attention()</h2>
|
||||
<a id="a835f90451dd42ce2d52352d5cce0722a" name="a835f90451dd42ce2d52352d5cce0722a"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a835f90451dd42ce2d52352d5cce0722a">◆ </a></span>attention()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T, int BQ, int BK, int BD, int WM, int WN, typename AccumType = float> </div>
|
||||
template<typename T, int BQ, int BK, int BD, int WM, int WN, typename MaskType = float, typename AccumType = float> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void attention </td>
|
||||
@@ -176,6 +180,16 @@ template<typename T, int BQ, int BK, int BD, int WM, int WN, typename AccumTy
|
||||
<td></td>
|
||||
<td class="paramtype">const constant <a class="el" href="structmlx_1_1steel_1_1_attn_params.html">AttnParams</a> *</td> <td class="paramname"><span class="paramname"><em>params</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant <a class="el" href="structmlx_1_1steel_1_1_attn_mask_params.html">AttnMaskParams</a> *</td> <td class="paramname"><span class="paramname"><em>mask_params</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device MaskType *</td> <td class="paramname"><span class="paramname"><em>mask</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -228,6 +242,34 @@ template<typename T, int BQ, int BK, int BD, int WM, int WN, typename AccumTy
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="abfa50278ba59a90e0acb7e5d94500741" name="abfa50278ba59a90e0acb7e5d94500741"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#abfa50278ba59a90e0acb7e5d94500741">◆ </a></span>do_causal</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constant bool do_causal</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a6ed0dd113fe7d471fc0b869b8c028c81" name="a6ed0dd113fe7d471fc0b869b8c028c81"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a6ed0dd113fe7d471fc0b869b8c028c81">◆ </a></span>has_mask</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constant bool has_mask</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div><!-- contents -->
|
||||
|
||||
Reference in New Issue
Block a user