This commit is contained in:
CircleCI Docs
2024-11-05 19:54:16 +00:00
parent 3addf172d9
commit 98e590e52d
51 changed files with 2277 additions and 1802 deletions

View File

@@ -140,9 +140,9 @@ Functions</h2></td></tr>
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd" id="r_a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplParams" colspan="2">template&lt;typename T , int group_size, int bits&gt; </td></tr>
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a8e13c7d895624f738d2a6d9893b687fd">qmv_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a8e13c7d895624f738d2a6d9893b687fd"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12" id="r_a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits&gt; </td></tr>
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4a8c8db7d5d480733726fd6d1a645e12">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a" id="r_a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits&gt; </td></tr>
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1546533c5b925b2fbb3bec870ec7487a">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a1546533c5b925b2fbb3bec870ec7487a"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8" id="r_af5750a35e8f5462218effba719f7f5b8"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32&gt; </td></tr>
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#af5750a35e8f5462218effba719f7f5b8">qmm_t_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &amp;K, const constant int &amp;N, const constant int &amp;M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:af5750a35e8f5462218effba719f7f5b8"><td class="memSeparator" colspan="2">&#160;</td></tr>
@@ -167,6 +167,9 @@ Functions</h2></td></tr>
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5" id="r_ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, bool batched&gt; </td></tr>
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">qvm</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8" id="r_ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, int split_k = 32&gt; </td></tr>
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ab8243818512d6078d23e6ffb65fd7bb8">qvm_split_k</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &amp;final_block_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10" id="r_abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32&gt; </td></tr>
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#abe2e3ef0ee4ec2cb61dc5330ad463d10">qmm_t</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;K, const constant int &amp;N, const constant int &amp;M, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memSeparator" colspan="2">&#160;</td></tr>
@@ -2485,8 +2488,8 @@ template&lt;typename T , const int group_size, const int bits, bool batched&gt;
</div>
</div>
<a id="a4a8c8db7d5d480733726fd6d1a645e12" name="a4a8c8db7d5d480733726fd6d1a645e12"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a4a8c8db7d5d480733726fd6d1a645e12">&#9670;&#160;</a></span>qvm_impl()</h2>
<a id="a1546533c5b925b2fbb3bec870ec7487a" name="a1546533c5b925b2fbb3bec870ec7487a"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a1546533c5b925b2fbb3bec870ec7487a">&#9670;&#160;</a></span>qvm_impl()</h2>
<div class="memitem">
<div class="memproto">
@@ -2518,6 +2521,69 @@ template&lt;typename T , const int group_size, const int bits&gt; </div>
<td></td>
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>in_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_gid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lid</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="ab8243818512d6078d23e6ffb65fd7bb8" name="ab8243818512d6078d23e6ffb65fd7bb8"></a>
<h2 class="memtitle"><span class="permalink"><a href="#ab8243818512d6078d23e6ffb65fd7bb8">&#9670;&#160;</a></span>qvm_split_k()</h2>
<div class="memitem">
<div class="memproto">
<div class="memtemplate">
template&lt;typename T , const int group_size, const int bits, int split_k = 32&gt; </div>
<table class="memname">
<tr>
<td class="memname">void qvm_split_k </td>
<td>(</td>
<td class="paramtype">const device uint32_t *</td> <td class="paramname"><span class="paramname"><em>w</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>scales</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>biases</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>x</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
@@ -2528,6 +2594,51 @@ template&lt;typename T , const int group_size, const int bits&gt; </div>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>x_batch_ndims</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>x_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>x_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>w_batch_ndims</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>w_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>w_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>s_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>b_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>final_block_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>