mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
rebase
This commit is contained in:
121
docs/build/html/quantized_8h.html
vendored
121
docs/build/html/quantized_8h.html
vendored
@@ -140,9 +140,9 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd" id="r_a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplParams" colspan="2">template<typename T , int group_size, int bits> </td></tr>
|
||||
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a8e13c7d895624f738d2a6d9893b687fd">qmv_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a8e13c7d895624f738d2a6d9893b687fd"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12" id="r_a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits> </td></tr>
|
||||
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4a8c8db7d5d480733726fd6d1a645e12">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a" id="r_a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits> </td></tr>
|
||||
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1546533c5b925b2fbb3bec870ec7487a">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:a1546533c5b925b2fbb3bec870ec7487a"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8" id="r_af5750a35e8f5462218effba719f7f5b8"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32> </td></tr>
|
||||
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#af5750a35e8f5462218effba719f7f5b8">qmm_t_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:af5750a35e8f5462218effba719f7f5b8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -167,6 +167,9 @@ Functions</h2></td></tr>
|
||||
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5" id="r_ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, bool batched> </td></tr>
|
||||
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">qvm</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8" id="r_ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, int split_k = 32> </td></tr>
|
||||
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ab8243818512d6078d23e6ffb65fd7bb8">qvm_split_k</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10" id="r_abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplParams" colspan="2">template<typename T , const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32> </td></tr>
|
||||
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplItemLeft" align="right" valign="top">void </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#abe2e3ef0ee4ec2cb61dc5330ad463d10">qmm_t</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
|
||||
<tr class="separator:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
@@ -2485,8 +2488,8 @@ template<typename T , const int group_size, const int bits, bool batched>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a4a8c8db7d5d480733726fd6d1a645e12" name="a4a8c8db7d5d480733726fd6d1a645e12"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a4a8c8db7d5d480733726fd6d1a645e12">◆ </a></span>qvm_impl()</h2>
|
||||
<a id="a1546533c5b925b2fbb3bec870ec7487a" name="a1546533c5b925b2fbb3bec870ec7487a"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1546533c5b925b2fbb3bec870ec7487a">◆ </a></span>qvm_impl()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
@@ -2518,6 +2521,69 @@ template<typename T , const int group_size, const int bits> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>in_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_gid</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lid</em></span> )</td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="ab8243818512d6078d23e6ffb65fd7bb8" name="ab8243818512d6078d23e6ffb65fd7bb8"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#ab8243818512d6078d23e6ffb65fd7bb8">◆ </a></span>qvm_split_k()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<typename T , const int group_size, const int bits, int split_k = 32> </div>
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname">void qvm_split_k </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">const device uint32_t *</td> <td class="paramname"><span class="paramname"><em>w</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>scales</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>biases</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>x</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
@@ -2528,6 +2594,51 @@ template<typename T , const int group_size, const int bits> </div>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>x_batch_ndims</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>x_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>x_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>w_batch_ndims</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>w_shape</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>w_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>s_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>b_strides</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">const constant int &</td> <td class="paramname"><span class="paramname"><em>final_block_size</em></span>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
|
Reference in New Issue
Block a user