mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-07 12:48:15 +08:00
rebase
This commit is contained in:
75
docs/build/html/sdpa__vector_8h.html
vendored
75
docs/build/html/sdpa__vector_8h.html
vendored
@@ -114,24 +114,26 @@ $(function(){initNavTree('sdpa__vector_8h.html',''); initResizable(true); });
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a826f7a3c7ab843abc0842241db3e57b3" id="r_a826f7a3c7ab843abc0842241db3e57b3"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:a826f7a3c7ab843abc0842241db3e57b3"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a826f7a3c7ab843abc0842241db3e57b3">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a826f7a3c7ab843abc0842241db3e57b3"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aae1a2f23b03e24734805b08ebc5c1a59" id="r_aae1a2f23b03e24734805b08ebc5c1a59"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:aae1a2f23b03e24734805b08ebc5c1a59"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aae1a2f23b03e24734805b08ebc5c1a59">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:aae1a2f23b03e24734805b08ebc5c1a59"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1368cf3618a4e03dbf743b3463205efe" id="r_a1368cf3618a4e03dbf743b3463205efe"><td class="memTemplParams" colspan="2">template<typename T, int D> </td></tr>
|
||||
<tr class="memitem:a1368cf3618a4e03dbf743b3463205efe"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1368cf3618a4e03dbf743b3463205efe">sdpa_vector_2pass_2</a> (const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a1368cf3618a4e03dbf743b3463205efe"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18" id="r_aa83885125881230b6c4657dd3d0eba18"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aa83885125881230b6c4657dd3d0eba18">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:aa83885125881230b6c4657dd3d0eba18"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac" id="r_ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae2a4a8d17e571578ed529f4d4afe93ac">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2" id="r_ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplParams" colspan="2">template<typename T, int D> </td></tr>
|
||||
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae1be83816bf9332277dab185aa1b58c2">sdpa_vector_2pass_2</a> (const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ae1be83816bf9332277dab185aa1b58c2"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="var-members" name="var-members"></a>
|
||||
Variables</h2></td></tr>
|
||||
<tr class="memitem:a6ed0dd113fe7d471fc0b869b8c028c81" id="r_a6ed0dd113fe7d471fc0b869b8c028c81"><td class="memItemLeft" align="right" valign="top">constant bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#a6ed0dd113fe7d471fc0b869b8c028c81">has_mask</a></td></tr>
|
||||
<tr class="separator:a6ed0dd113fe7d471fc0b869b8c028c81"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a0c2c54bcc20cc4783a5040d47fa3ba81" id="r_a0c2c54bcc20cc4783a5040d47fa3ba81"><td class="memItemLeft" align="right" valign="top">constant bool </td><td class="memItemRight" valign="bottom"><a class="el" href="#a0c2c54bcc20cc4783a5040d47fa3ba81">query_transposed</a></td></tr>
|
||||
<tr class="separator:a0c2c54bcc20cc4783a5040d47fa3ba81"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="a826f7a3c7ab843abc0842241db3e57b3" name="a826f7a3c7ab843abc0842241db3e57b3"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a826f7a3c7ab843abc0842241db3e57b3">◆ </a></span>sdpa_vector()</h2>
|
||||
<a id="aa83885125881230b6c4657dd3d0eba18" name="aa83885125881230b6c4657dd3d0eba18"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aa83885125881230b6c4657dd3d0eba18">◆ </a></span>sdpa_vector()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -191,7 +193,12 @@ template<typename T, int D, int V = D> </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_seq_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_kv_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_q_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
@@ -203,6 +210,11 @@ template<typename T, int D, int V = D> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tpg</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -218,8 +230,8 @@ template<typename T, int D, int V = D> </div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aae1a2f23b03e24734805b08ebc5c1a59" name="aae1a2f23b03e24734805b08ebc5c1a59"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aae1a2f23b03e24734805b08ebc5c1a59">◆ </a></span>sdpa_vector_2pass_1()</h2>
|
||||
<a id="ae2a4a8d17e571578ed529f4d4afe93ac" name="ae2a4a8d17e571578ed529f4d4afe93ac"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ae2a4a8d17e571578ed529f4d4afe93ac">◆ </a></span>sdpa_vector_2pass_1()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -289,7 +301,12 @@ template<typename T, int D, int V = D> </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_seq_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_kv_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>mask_q_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
@@ -301,6 +318,11 @@ template<typename T, int D, int V = D> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tpg</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -316,8 +338,8 @@ template<typename T, int D, int V = D> </div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a1368cf3618a4e03dbf743b3463205efe" name="a1368cf3618a4e03dbf743b3463205efe"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1368cf3618a4e03dbf743b3463205efe">◆ </a></span>sdpa_vector_2pass_2()</h2>
|
||||
<a id="ae1be83816bf9332277dab185aa1b58c2" name="ae1be83816bf9332277dab185aa1b58c2"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ae1be83816bf9332277dab185aa1b58c2">◆ </a></span>sdpa_vector_2pass_2()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -349,6 +371,11 @@ template<typename T, int D> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tpg</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -377,6 +404,20 @@ template<typename T, int D> </div>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a0c2c54bcc20cc4783a5040d47fa3ba81" name="a0c2c54bcc20cc4783a5040d47fa3ba81"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a0c2c54bcc20cc4783a5040d47fa3ba81">◆ </a></span>query_transposed</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">constant bool query_transposed</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div><!-- contents -->
|
||||
|
||||
Reference in New Issue
Block a user