mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-07 12:48:15 +08:00
rebase
This commit is contained in:
48
docs/build/html/sdpa__vector_8h.html
vendored
48
docs/build/html/sdpa__vector_8h.html
vendored
@@ -114,12 +114,12 @@ $(function(){initNavTree('sdpa__vector_8h.html',''); initResizable(true); });
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18" id="r_aa83885125881230b6c4657dd3d0eba18"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aa83885125881230b6c4657dd3d0eba18">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:aa83885125881230b6c4657dd3d0eba18"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac" id="r_ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae2a4a8d17e571578ed529f4d4afe93ac">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3289383906473a108e6aee1993a72816" id="r_a3289383906473a108e6aee1993a72816"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:a3289383906473a108e6aee1993a72816"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a3289383906473a108e6aee1993a72816">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a3289383906473a108e6aee1993a72816"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1cdf4f03898ffe2800519892f7f6e0ad" id="r_a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memTemplParams" colspan="2">template<typename T, int D, int V = D> </td></tr>
|
||||
<tr class="memitem:a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1cdf4f03898ffe2800519892f7f6e0ad">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2" id="r_ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplParams" colspan="2">template<typename T, int D> </td></tr>
|
||||
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae1be83816bf9332277dab185aa1b58c2">sdpa_vector_2pass_2</a> (const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ae1be83816bf9332277dab185aa1b58c2"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -132,8 +132,8 @@ Variables</h2></td></tr>
|
||||
<tr class="separator:a0c2c54bcc20cc4783a5040d47fa3ba81"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Function Documentation</h2>
|
||||
<a id="aa83885125881230b6c4657dd3d0eba18" name="aa83885125881230b6c4657dd3d0eba18"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aa83885125881230b6c4657dd3d0eba18">◆ </a></span>sdpa_vector()</h2>
|
||||
<a id="a3289383906473a108e6aee1993a72816" name="a3289383906473a108e6aee1993a72816"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3289383906473a108e6aee1993a72816">◆ </a></span>sdpa_vector()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -173,12 +173,22 @@ template<typename T, int D, int V = D> </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_head_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_head_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
@@ -230,8 +240,8 @@ template<typename T, int D, int V = D> </div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ae2a4a8d17e571578ed529f4d4afe93ac" name="ae2a4a8d17e571578ed529f4d4afe93ac"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ae2a4a8d17e571578ed529f4d4afe93ac">◆ </a></span>sdpa_vector_2pass_1()</h2>
|
||||
<a id="a1cdf4f03898ffe2800519892f7f6e0ad" name="a1cdf4f03898ffe2800519892f7f6e0ad"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1cdf4f03898ffe2800519892f7f6e0ad">◆ </a></span>sdpa_vector_2pass_1()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -281,12 +291,22 @@ template<typename T, int D, int V = D> </div>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_head_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>k_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_head_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t &</td> <td class="paramname"><span class="paramname"><em>v_seq_stride</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
|
||||
Reference in New Issue
Block a user