This commit is contained in:
CircleCI Docs
2025-03-20 22:37:22 +00:00
parent a943912d4c
commit cecec56a99
858 changed files with 18494 additions and 17475 deletions

View File

@@ -114,12 +114,12 @@ $(function(){initNavTree('sdpa__vector_8h.html',''); initResizable(true); });
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
Functions</h2></td></tr>
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18" id="r_aa83885125881230b6c4657dd3d0eba18"><td class="memTemplParams" colspan="2">template&lt;typename T, int D, int V = D&gt; </td></tr>
<tr class="memitem:aa83885125881230b6c4657dd3d0eba18"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#aa83885125881230b6c4657dd3d0eba18">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant size_t &amp;v_stride, const constant float &amp;scale, const device bool *mask, const constant int &amp;mask_kv_seq_stride, const constant int &amp;mask_q_seq_stride, const constant int &amp;mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:aa83885125881230b6c4657dd3d0eba18"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac" id="r_ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplParams" colspan="2">template&lt;typename T, int D, int V = D&gt; </td></tr>
<tr class="memitem:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae2a4a8d17e571578ed529f4d4afe93ac">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant size_t &amp;v_stride, const constant float &amp;scale, const device bool *mask, const constant int &amp;mask_kv_seq_stride, const constant int &amp;mask_q_seq_stride, const constant int &amp;mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ae2a4a8d17e571578ed529f4d4afe93ac"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3289383906473a108e6aee1993a72816" id="r_a3289383906473a108e6aee1993a72816"><td class="memTemplParams" colspan="2">template&lt;typename T, int D, int V = D&gt; </td></tr>
<tr class="memitem:a3289383906473a108e6aee1993a72816"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a3289383906473a108e6aee1993a72816">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_head_stride, const constant size_t &amp;k_seq_stride, const constant size_t &amp;v_head_stride, const constant size_t &amp;v_seq_stride, const constant float &amp;scale, const device bool *mask, const constant int &amp;mask_kv_seq_stride, const constant int &amp;mask_q_seq_stride, const constant int &amp;mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a3289383906473a108e6aee1993a72816"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a1cdf4f03898ffe2800519892f7f6e0ad" id="r_a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memTemplParams" colspan="2">template&lt;typename T, int D, int V = D&gt; </td></tr>
<tr class="memitem:a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1cdf4f03898ffe2800519892f7f6e0ad">sdpa_vector_2pass_1</a> (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_head_stride, const constant size_t &amp;k_seq_stride, const constant size_t &amp;v_head_stride, const constant size_t &amp;v_seq_stride, const constant float &amp;scale, const device bool *mask, const constant int &amp;mask_kv_seq_stride, const constant int &amp;mask_q_seq_stride, const constant int &amp;mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a1cdf4f03898ffe2800519892f7f6e0ad"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2" id="r_ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplParams" colspan="2">template&lt;typename T, int D&gt; </td></tr>
<tr class="memitem:ae1be83816bf9332277dab185aa1b58c2"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ae1be83816bf9332277dab185aa1b58c2">sdpa_vector_2pass_2</a> (const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ae1be83816bf9332277dab185aa1b58c2"><td class="memSeparator" colspan="2">&#160;</td></tr>
@@ -132,8 +132,8 @@ Variables</h2></td></tr>
<tr class="separator:a0c2c54bcc20cc4783a5040d47fa3ba81"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
<h2 class="groupheader">Function Documentation</h2>
<a id="aa83885125881230b6c4657dd3d0eba18" name="aa83885125881230b6c4657dd3d0eba18"></a>
<h2 class="memtitle"><span class="permalink"><a href="#aa83885125881230b6c4657dd3d0eba18">&#9670;&#160;</a></span>sdpa_vector()</h2>
<a id="a3289383906473a108e6aee1993a72816" name="a3289383906473a108e6aee1993a72816"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a3289383906473a108e6aee1993a72816">&#9670;&#160;</a></span>sdpa_vector()</h2>
<div class="memitem">
<div class="memproto">
@@ -173,12 +173,22 @@ template&lt;typename T, int D, int V = D&gt; </div>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_head_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_seq_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_head_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_seq_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
@@ -230,8 +240,8 @@ template&lt;typename T, int D, int V = D&gt; </div>
</div>
</div>
<a id="ae2a4a8d17e571578ed529f4d4afe93ac" name="ae2a4a8d17e571578ed529f4d4afe93ac"></a>
<h2 class="memtitle"><span class="permalink"><a href="#ae2a4a8d17e571578ed529f4d4afe93ac">&#9670;&#160;</a></span>sdpa_vector_2pass_1()</h2>
<a id="a1cdf4f03898ffe2800519892f7f6e0ad" name="a1cdf4f03898ffe2800519892f7f6e0ad"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a1cdf4f03898ffe2800519892f7f6e0ad">&#9670;&#160;</a></span>sdpa_vector_2pass_1()</h2>
<div class="memitem">
<div class="memproto">
@@ -281,12 +291,22 @@ template&lt;typename T, int D, int V = D&gt; </div>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_head_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_seq_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_head_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_seq_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>