This commit is contained in:
CircleCI Docs
2025-01-09 21:56:20 +00:00
parent 04b749a588
commit d8d647015b
2642 changed files with 137687 additions and 70861 deletions

View File

@@ -3,7 +3,7 @@
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=11"/>
<meta name="generator" content="Doxygen 1.12.0"/>
<meta name="generator" content="Doxygen 1.13.1"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>MLX: mlx/backend/metal/kernels/gemv_masked.h Source File</title>
<link href="tabs.css" rel="stylesheet" type="text/css"/>
@@ -11,11 +11,18 @@
<script type="text/javascript" src="dynsections.js"></script>
<script type="text/javascript" src="clipboard.js"></script>
<link href="navtree.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="navtreedata.js"></script>
<script type="text/javascript" src="navtree.js"></script>
<script type="text/javascript" src="resize.js"></script>
<script type="text/javascript" src="cookie.js"></script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
$(function() { init_search(); });
/* @license-end */
</script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
</head>
<body>
@@ -28,12 +35,24 @@
<div id="projectname">MLX
</div>
</td>
<td> <div id="MSearchBox" class="MSearchBoxInactive">
<span class="left">
<span id="MSearchSelect" onmouseover="return searchBox.OnSearchSelectShow()" onmouseout="return searchBox.OnSearchSelectHide()">&#160;</span>
<input type="text" id="MSearchField" value="" placeholder="Search" accesskey="S"
onfocus="searchBox.OnSearchFieldFocus(true)"
onblur="searchBox.OnSearchFieldFocus(false)"
onkeyup="searchBox.OnSearchFieldChange(event)"/>
</span><span class="right">
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.svg" alt=""/></a>
</span>
</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.12.0 -->
<!-- Generated by Doxygen 1.13.1 -->
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
var searchBox = new SearchBox("searchBox", "search/",'.html');
@@ -44,22 +63,23 @@ var searchBox = new SearchBox("searchBox", "search/",'.html');
$(function() { codefold.init(0); });
/* @license-end */
</script>
<script type="text/javascript" src="menudata.js"></script>
<script type="text/javascript" src="menu.js"></script>
</div><!-- top -->
<div id="side-nav" class="ui-resizable side-nav-resizable">
<div id="nav-tree">
<div id="nav-tree-contents">
<div id="nav-sync" class="sync"></div>
</div>
</div>
<div id="splitbar" style="-moz-user-select:none;"
class="ui-resizable-handle">
</div>
</div>
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
$(function() {
initMenu('',true,false,'search.php','Search',false);
$(function() { init_search(); });
});
/* @license-end */
</script>
<div id="main-nav"></div>
<script type="text/javascript">
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&amp;dn=expat.txt MIT */
$(function(){ initResizable(false); });
$(function(){initNavTree('kernels_2gemv__masked_8h_source.html',''); initResizable(true); });
/* @license-end */
</script>
<div id="doc-content">
<!-- window showing the filter options -->
<div id="MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
@@ -81,12 +101,6 @@ $(function(){ initResizable(false); });
</div>
</div>
<div id="nav-path" class="navpath">
<ul>
<li class="navelem"><a class="el" href="dir_938ab0ecf10b8b860ff766c820f665fd.html">mlx</a></li><li class="navelem"><a class="el" href="dir_1d446c9bd3c99228254c9484e0bc5c06.html">backend</a></li><li class="navelem"><a class="el" href="dir_d0c977ea65824390717cdb7efc36c157.html">metal</a></li><li class="navelem"><a class="el" href="dir_70a37effa88bcbd6b791977fa1e64356.html">kernels</a></li> </ul>
</div>
</div><!-- top -->
<div id="doc-content">
<div class="header">
<div class="headertitle"><div class="title">gemv_masked.h</div></div>
</div><!--header-->
@@ -127,7 +141,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span>};</div>
</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> </div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#a1480c8cdff1cae1462a5a71632969bca"> 27</a></span><span class="keyword">typedef</span> <span class="keyword">struct </span><a class="code hl_struct" href="struct___no_mask.html">_NoMask</a> <a class="code hl_struct" href="struct___no_mask.html">nomask_t</a>;</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#a1480c8cdff1cae1462a5a71632969bca"> 27</a></span><span class="keyword">typedef</span> <span class="keyword">struct </span><a class="code hl_struct" href="struct___no_mask.html">_NoMask</a> <a class="code hl_typedef" href="kernels_2gemv__masked_8h.html#a1480c8cdff1cae1462a5a71632969bca">nomask_t</a>;</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> </div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> OutT, <span class="keyword">typename</span> InT = OutT&gt;</div>
<div class="foldopen" id="foldopen00030" data-start="{" data-end="};">
@@ -455,7 +469,7 @@ $(function(){ initResizable(false); });
</div>
<div class="line"><a id="l00342" name="l00342"></a><span class="lineno"> 342</span>};</div>
</div>
<div class="line"><a id="l00343" name="l00343"></a><span class="lineno"> 343</span> </div>
<div class="line"><a id="l00343" name="l00343"></a><span class="lineno"> 343</span></div>
<div class="line"><a id="l00347" name="l00347"></a><span class="lineno"> 347</span> </div>
<div class="line"><a id="l00348" name="l00348"></a><span class="lineno"> 348</span><span class="keyword">template</span> &lt;</div>
<div class="line"><a id="l00349" name="l00349"></a><span class="lineno"> 349</span> <span class="keyword">typename</span> T,</div>
@@ -733,7 +747,7 @@ $(function(){ initResizable(false); });
</div>
<div class="line"><a id="l00619" name="l00619"></a><span class="lineno"> 619</span>};</div>
</div>
<div class="line"><a id="l00620" name="l00620"></a><span class="lineno"> 620</span> </div>
<div class="line"><a id="l00620" name="l00620"></a><span class="lineno"> 620</span></div>
<div class="line"><a id="l00624" name="l00624"></a><span class="lineno"> 624</span> </div>
<div class="line"><a id="l00625" name="l00625"></a><span class="lineno"> 625</span><span class="keyword">template</span> &lt;</div>
<div class="line"><a id="l00626" name="l00626"></a><span class="lineno"> 626</span> <span class="keyword">typename</span> T,</div>
@@ -747,7 +761,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00634" name="l00634"></a><span class="lineno"> 634</span> <span class="keyword">const</span> <span class="keywordtype">int</span> TN, <span class="comment">/* Thread cols (in elements) */</span></div>
<div class="line"><a id="l00635" name="l00635"></a><span class="lineno"> 635</span> <span class="keyword">const</span> <span class="keywordtype">bool</span> kDoNCBatch&gt; <span class="comment">/* Batch ndim &gt; 1 */</span></div>
<div class="foldopen" id="foldopen00636" data-start="{" data-end="}">
<div class="line"><a id="l00636" name="l00636"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#ab3070d14cdecb1dd7dc220a551da6b7b"> 636</a></span>[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] <span class="keywordtype">void</span> <a class="code hl_function" href="kernels_2gemv__masked_8h.html#ab3070d14cdecb1dd7dc220a551da6b7b">gemv_masked</a>(</div>
<div class="line"><a id="l00636" name="l00636"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#af890b6ac155165f8ee0c600363938341"> 636</a></span>[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] <span class="keywordtype">void</span> <a class="code hl_function" href="kernels_2gemv__masked_8h.html#af890b6ac155165f8ee0c600363938341">gemv_masked</a>(</div>
<div class="line"><a id="l00637" name="l00637"></a><span class="lineno"> 637</span> <span class="keyword">const</span> device T* mat [[buffer(0)]],</div>
<div class="line"><a id="l00638" name="l00638"></a><span class="lineno"> 638</span> <span class="keyword">const</span> device T* in_vec [[buffer(1)]],</div>
<div class="line"><a id="l00639" name="l00639"></a><span class="lineno"> 639</span> device T* out_vec [[buffer(3)]],</div>
@@ -756,13 +770,13 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00642" name="l00642"></a><span class="lineno"> 642</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; marix_ld [[buffer(6)]],</div>
<div class="line"><a id="l00643" name="l00643"></a><span class="lineno"> 643</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; batch_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00644" name="l00644"></a><span class="lineno"> 644</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* batch_shape [[buffer(10)]],</div>
<div class="line"><a id="l00645" name="l00645"></a><span class="lineno"> 645</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* vector_batch_stride [[buffer(11)]],</div>
<div class="line"><a id="l00646" name="l00646"></a><span class="lineno"> 646</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* matrix_batch_stride [[buffer(12)]],</div>
<div class="line"><a id="l00645" name="l00645"></a><span class="lineno"> 645</span> <span class="keyword">const</span> constant int64_t* vector_batch_stride [[buffer(11)]],</div>
<div class="line"><a id="l00646" name="l00646"></a><span class="lineno"> 646</span> <span class="keyword">const</span> constant int64_t* matrix_batch_stride [[buffer(12)]],</div>
<div class="line"><a id="l00647" name="l00647"></a><span class="lineno"> 647</span> <span class="keyword">const</span> device out_mask_t* out_mask [[buffer(20)]],</div>
<div class="line"><a id="l00648" name="l00648"></a><span class="lineno"> 648</span> <span class="keyword">const</span> device op_mask_t* mat_mask [[buffer(21)]],</div>
<div class="line"><a id="l00649" name="l00649"></a><span class="lineno"> 649</span> <span class="keyword">const</span> device op_mask_t* vec_mask [[buffer(22)]],</div>
<div class="line"><a id="l00650" name="l00650"></a><span class="lineno"> 650</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* mask_strides [[buffer(23)]],</div>
<div class="line"><a id="l00651" name="l00651"></a><span class="lineno"> 651</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_batch_strides [[buffer(24)]],</div>
<div class="line"><a id="l00651" name="l00651"></a><span class="lineno"> 651</span> <span class="keyword">const</span> constant int64_t* mask_batch_strides [[buffer(24)]],</div>
<div class="line"><a id="l00652" name="l00652"></a><span class="lineno"> 652</span> uint3 tid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00653" name="l00653"></a><span class="lineno"> 653</span> uint3 lid [[thread_position_in_threadgroup]],</div>
<div class="line"><a id="l00654" name="l00654"></a><span class="lineno"> 654</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
@@ -777,20 +791,20 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00663" name="l00663"></a><span class="lineno"> 663</span> </div>
<div class="line"><a id="l00664" name="l00664"></a><span class="lineno"> 664</span> <span class="comment">// Update batch offsets</span></div>
<div class="line"><a id="l00665" name="l00665"></a><span class="lineno"> 665</span> <span class="keywordflow">if</span> (kDoNCBatch) {</div>
<div class="line"><a id="l00666" name="l00666"></a><span class="lineno"> 666</span> in_vec += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, vector_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00667" name="l00667"></a><span class="lineno"> 667</span> mat += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, matrix_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00666" name="l00666"></a><span class="lineno"> 666</span> in_vec += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, vector_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00667" name="l00667"></a><span class="lineno"> 667</span> mat += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, matrix_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00668" name="l00668"></a><span class="lineno"> 668</span> </div>
<div class="line"><a id="l00669" name="l00669"></a><span class="lineno"> 669</span> <span class="keywordflow">if</span> (has_output_mask) {</div>
<div class="line"><a id="l00670" name="l00670"></a><span class="lineno"> 670</span> out_mask +=</div>
<div class="line"><a id="l00671" name="l00671"></a><span class="lineno"> 671</span> <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, mask_batch_strides, batch_ndim);</div>
<div class="line"><a id="l00671" name="l00671"></a><span class="lineno"> 671</span> <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, mask_batch_strides, batch_ndim);</div>
<div class="line"><a id="l00672" name="l00672"></a><span class="lineno"> 672</span> mask_batch_strides += batch_ndim;</div>
<div class="line"><a id="l00673" name="l00673"></a><span class="lineno"> 673</span> }</div>
<div class="line"><a id="l00674" name="l00674"></a><span class="lineno"> 674</span> </div>
<div class="line"><a id="l00675" name="l00675"></a><span class="lineno"> 675</span> <span class="keywordflow">if</span> (has_operand_mask) {</div>
<div class="line"><a id="l00676" name="l00676"></a><span class="lineno"> 676</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_strides_mat = mask_batch_strides;</div>
<div class="line"><a id="l00677" name="l00677"></a><span class="lineno"> 677</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_strides_vec = mask_strides_mat + batch_ndim;</div>
<div class="line"><a id="l00676" name="l00676"></a><span class="lineno"> 676</span> <span class="keyword">const</span> constant <span class="keyword">auto</span>* mask_strides_mat = mask_batch_strides;</div>
<div class="line"><a id="l00677" name="l00677"></a><span class="lineno"> 677</span> <span class="keyword">const</span> constant <span class="keyword">auto</span>* mask_strides_vec = mask_strides_mat + batch_ndim;</div>
<div class="line"><a id="l00678" name="l00678"></a><span class="lineno"> 678</span> </div>
<div class="line"><a id="l00679" name="l00679"></a><span class="lineno"> 679</span> ulong2 batch_offsets = <a class="code hl_function" href="backend_2metal_2kernels_2steel_2utils_8h.html#aaf4974425147d6f26d031691e321637f">elem_to_loc_broadcast</a>(</div>
<div class="line"><a id="l00679" name="l00679"></a><span class="lineno"> 679</span> ulong2 batch_offsets = <a class="code hl_function" href="backend_2metal_2kernels_2steel_2utils_8h.html#af62bacceef7d93f8c1ba4fcf5b1adfe6">elem_to_loc_broadcast</a>(</div>
<div class="line"><a id="l00680" name="l00680"></a><span class="lineno"> 680</span> tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);</div>
<div class="line"><a id="l00681" name="l00681"></a><span class="lineno"> 681</span> </div>
<div class="line"><a id="l00682" name="l00682"></a><span class="lineno"> 682</span> mat_mask += batch_offsets.x;</div>
@@ -832,7 +846,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00718" name="l00718"></a><span class="lineno"> 718</span> simd_lid);</div>
<div class="line"><a id="l00719" name="l00719"></a><span class="lineno"> 719</span>}</div>
</div>
<div class="line"><a id="l00720" name="l00720"></a><span class="lineno"> 720</span> </div>
<div class="line"><a id="l00720" name="l00720"></a><span class="lineno"> 720</span></div>
<div class="line"><a id="l00724" name="l00724"></a><span class="lineno"> 724</span> </div>
<div class="line"><a id="l00725" name="l00725"></a><span class="lineno"> 725</span><span class="keyword">template</span> &lt;</div>
<div class="line"><a id="l00726" name="l00726"></a><span class="lineno"> 726</span> <span class="keyword">typename</span> T,</div>
@@ -846,7 +860,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00734" name="l00734"></a><span class="lineno"> 734</span> <span class="keyword">const</span> <span class="keywordtype">int</span> TN, <span class="comment">/* Thread cols (in elements) */</span></div>
<div class="line"><a id="l00735" name="l00735"></a><span class="lineno"> 735</span> <span class="keyword">const</span> <span class="keywordtype">bool</span> kDoNCBatch&gt; <span class="comment">/* Batch ndim &gt; 1 */</span></div>
<div class="foldopen" id="foldopen00736" data-start="{" data-end="}">
<div class="line"><a id="l00736" name="l00736"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#a0c8d353fc453e448b2d0ed9a19431b63"> 736</a></span>[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] <span class="keywordtype">void</span> <a class="code hl_function" href="kernels_2gemv__masked_8h.html#a0c8d353fc453e448b2d0ed9a19431b63">gemv_t_masked</a>(</div>
<div class="line"><a id="l00736" name="l00736"></a><span class="lineno"><a class="line" href="kernels_2gemv__masked_8h.html#ae5b4a5124ddf92a984258a0be1ff0f4f"> 736</a></span>[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] <span class="keywordtype">void</span> <a class="code hl_function" href="kernels_2gemv__masked_8h.html#ae5b4a5124ddf92a984258a0be1ff0f4f">gemv_t_masked</a>(</div>
<div class="line"><a id="l00737" name="l00737"></a><span class="lineno"> 737</span> <span class="keyword">const</span> device T* mat [[buffer(0)]],</div>
<div class="line"><a id="l00738" name="l00738"></a><span class="lineno"> 738</span> <span class="keyword">const</span> device T* in_vec [[buffer(1)]],</div>
<div class="line"><a id="l00739" name="l00739"></a><span class="lineno"> 739</span> device T* out_vec [[buffer(3)]],</div>
@@ -855,13 +869,13 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00742" name="l00742"></a><span class="lineno"> 742</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; marix_ld [[buffer(6)]],</div>
<div class="line"><a id="l00743" name="l00743"></a><span class="lineno"> 743</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; batch_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00744" name="l00744"></a><span class="lineno"> 744</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* batch_shape [[buffer(10)]],</div>
<div class="line"><a id="l00745" name="l00745"></a><span class="lineno"> 745</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* vector_batch_stride [[buffer(11)]],</div>
<div class="line"><a id="l00746" name="l00746"></a><span class="lineno"> 746</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* matrix_batch_stride [[buffer(12)]],</div>
<div class="line"><a id="l00745" name="l00745"></a><span class="lineno"> 745</span> <span class="keyword">const</span> constant int64_t* vector_batch_stride [[buffer(11)]],</div>
<div class="line"><a id="l00746" name="l00746"></a><span class="lineno"> 746</span> <span class="keyword">const</span> constant int64_t* matrix_batch_stride [[buffer(12)]],</div>
<div class="line"><a id="l00747" name="l00747"></a><span class="lineno"> 747</span> <span class="keyword">const</span> device out_mask_t* out_mask [[buffer(20)]],</div>
<div class="line"><a id="l00748" name="l00748"></a><span class="lineno"> 748</span> <span class="keyword">const</span> device op_mask_t* mat_mask [[buffer(21)]],</div>
<div class="line"><a id="l00749" name="l00749"></a><span class="lineno"> 749</span> <span class="keyword">const</span> device op_mask_t* vec_mask [[buffer(22)]],</div>
<div class="line"><a id="l00750" name="l00750"></a><span class="lineno"> 750</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* mask_strides [[buffer(23)]],</div>
<div class="line"><a id="l00751" name="l00751"></a><span class="lineno"> 751</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_batch_strides [[buffer(24)]],</div>
<div class="line"><a id="l00751" name="l00751"></a><span class="lineno"> 751</span> <span class="keyword">const</span> constant int64_t* mask_batch_strides [[buffer(24)]],</div>
<div class="line"><a id="l00752" name="l00752"></a><span class="lineno"> 752</span> uint3 tid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00753" name="l00753"></a><span class="lineno"> 753</span> uint3 lid [[thread_position_in_threadgroup]],</div>
<div class="line"><a id="l00754" name="l00754"></a><span class="lineno"> 754</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
@@ -876,20 +890,20 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00763" name="l00763"></a><span class="lineno"> 763</span> </div>
<div class="line"><a id="l00764" name="l00764"></a><span class="lineno"> 764</span> <span class="comment">// Update batch offsets</span></div>
<div class="line"><a id="l00765" name="l00765"></a><span class="lineno"> 765</span> <span class="keywordflow">if</span> (kDoNCBatch) {</div>
<div class="line"><a id="l00766" name="l00766"></a><span class="lineno"> 766</span> in_vec += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, vector_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00767" name="l00767"></a><span class="lineno"> 767</span> mat += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, matrix_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00766" name="l00766"></a><span class="lineno"> 766</span> in_vec += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, vector_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00767" name="l00767"></a><span class="lineno"> 767</span> mat += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, matrix_batch_stride, batch_ndim);</div>
<div class="line"><a id="l00768" name="l00768"></a><span class="lineno"> 768</span> </div>
<div class="line"><a id="l00769" name="l00769"></a><span class="lineno"> 769</span> <span class="keywordflow">if</span> (has_output_mask) {</div>
<div class="line"><a id="l00770" name="l00770"></a><span class="lineno"> 770</span> out_mask +=</div>
<div class="line"><a id="l00771" name="l00771"></a><span class="lineno"> 771</span> <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a>(tid.z, batch_shape, mask_batch_strides, batch_ndim);</div>
<div class="line"><a id="l00771" name="l00771"></a><span class="lineno"> 771</span> <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a>(tid.z, batch_shape, mask_batch_strides, batch_ndim);</div>
<div class="line"><a id="l00772" name="l00772"></a><span class="lineno"> 772</span> mask_batch_strides += batch_ndim;</div>
<div class="line"><a id="l00773" name="l00773"></a><span class="lineno"> 773</span> }</div>
<div class="line"><a id="l00774" name="l00774"></a><span class="lineno"> 774</span> </div>
<div class="line"><a id="l00775" name="l00775"></a><span class="lineno"> 775</span> <span class="keywordflow">if</span> (has_operand_mask) {</div>
<div class="line"><a id="l00776" name="l00776"></a><span class="lineno"> 776</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_strides_mat = mask_batch_strides;</div>
<div class="line"><a id="l00777" name="l00777"></a><span class="lineno"> 777</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* mask_strides_vec = mask_strides_mat + batch_ndim;</div>
<div class="line"><a id="l00776" name="l00776"></a><span class="lineno"> 776</span> <span class="keyword">const</span> constant <span class="keyword">auto</span>* mask_strides_mat = mask_batch_strides;</div>
<div class="line"><a id="l00777" name="l00777"></a><span class="lineno"> 777</span> <span class="keyword">const</span> constant <span class="keyword">auto</span>* mask_strides_vec = mask_strides_mat + batch_ndim;</div>
<div class="line"><a id="l00778" name="l00778"></a><span class="lineno"> 778</span> </div>
<div class="line"><a id="l00779" name="l00779"></a><span class="lineno"> 779</span> ulong2 batch_offsets = <a class="code hl_function" href="backend_2metal_2kernels_2steel_2utils_8h.html#aaf4974425147d6f26d031691e321637f">elem_to_loc_broadcast</a>(</div>
<div class="line"><a id="l00779" name="l00779"></a><span class="lineno"> 779</span> ulong2 batch_offsets = <a class="code hl_function" href="backend_2metal_2kernels_2steel_2utils_8h.html#af62bacceef7d93f8c1ba4fcf5b1adfe6">elem_to_loc_broadcast</a>(</div>
<div class="line"><a id="l00780" name="l00780"></a><span class="lineno"> 780</span> tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);</div>
<div class="line"><a id="l00781" name="l00781"></a><span class="lineno"> 781</span> </div>
<div class="line"><a id="l00782" name="l00782"></a><span class="lineno"> 782</span> mat_mask += batch_offsets.x;</div>
@@ -932,12 +946,13 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00819" name="l00819"></a><span class="lineno"> 819</span>}</div>
</div>
<div class="ttc" id="abackend_2metal_2kernels_2steel_2utils_8h_html"><div class="ttname"><a href="backend_2metal_2kernels_2steel_2utils_8h.html">utils.h</a></div></div>
<div class="ttc" id="abackend_2metal_2kernels_2steel_2utils_8h_html_aaf4974425147d6f26d031691e321637f"><div class="ttname"><a href="backend_2metal_2kernels_2steel_2utils_8h.html#aaf4974425147d6f26d031691e321637f">elem_to_loc_broadcast</a></div><div class="ttdeci">METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:7</div></div>
<div class="ttc" id="abackend_2metal_2kernels_2utils_8h_html_a22eaa505dbc7dd2a63a895f2e16712f5"><div class="ttname"><a href="backend_2metal_2kernels_2utils_8h.html#a22eaa505dbc7dd2a63a895f2e16712f5">elem_to_loc</a></div><div class="ttdeci">METAL_FUNC IdxT elem_to_loc(uint elem, constant const int *shape, constant const StrideT *strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:93</div></div>
<div class="ttc" id="abackend_2metal_2kernels_2steel_2utils_8h_html_af62bacceef7d93f8c1ba4fcf5b1adfe6"><div class="ttname"><a href="backend_2metal_2kernels_2steel_2utils_8h.html#af62bacceef7d93f8c1ba4fcf5b1adfe6">elem_to_loc_broadcast</a></div><div class="ttdeci">METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:7</div></div>
<div class="ttc" id="abackend_2metal_2kernels_2utils_8h_html_a497dd9f1a00c8a4303d8782158a0812a"><div class="ttname"><a href="backend_2metal_2kernels_2utils_8h.html#a497dd9f1a00c8a4303d8782158a0812a">elem_to_loc</a></div><div class="ttdeci">METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:93</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_a0386011c52d03e60885a31e6fbd903dd"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#a0386011c52d03e60885a31e6fbd903dd">MLX_MTL_CONST</a></div><div class="ttdeci">#define MLX_MTL_CONST</div><div class="ttdef"><b>Definition</b> gemv_masked.h:7</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_a069b682d7d21827461544817d722bfd3"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#a069b682d7d21827461544817d722bfd3">MLX_MTL_PRAGMA_UNROLL</a></div><div class="ttdeci">#define MLX_MTL_PRAGMA_UNROLL</div><div class="ttdef"><b>Definition</b> gemv_masked.h:8</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_a0c8d353fc453e448b2d0ed9a19431b63"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#a0c8d353fc453e448b2d0ed9a19431b63">gemv_t_masked</a></div><div class="ttdeci">void gemv_t_masked(const device T *mat, const device T *in_vec, device T *out_vec, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;marix_ld, const constant int &amp;batch_ndim, const constant int *batch_shape, const constant size_t *vector_batch_stride, const constant size_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant size_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)</div><div class="ttdoc">Vector matrix multiplication.</div><div class="ttdef"><b>Definition</b> gemv_masked.h:736</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_ab3070d14cdecb1dd7dc220a551da6b7b"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#ab3070d14cdecb1dd7dc220a551da6b7b">gemv_masked</a></div><div class="ttdeci">void gemv_masked(const device T *mat, const device T *in_vec, device T *out_vec, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;marix_ld, const constant int &amp;batch_ndim, const constant int *batch_shape, const constant size_t *vector_batch_stride, const constant size_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant size_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)</div><div class="ttdoc">Matrix vector multiplication.</div><div class="ttdef"><b>Definition</b> gemv_masked.h:636</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_a1480c8cdff1cae1462a5a71632969bca"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#a1480c8cdff1cae1462a5a71632969bca">nomask_t</a></div><div class="ttdeci">struct _NoMask nomask_t</div><div class="ttdef"><b>Definition</b> gemv_masked.h:27</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_ae5b4a5124ddf92a984258a0be1ff0f4f"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#ae5b4a5124ddf92a984258a0be1ff0f4f">gemv_t_masked</a></div><div class="ttdeci">void gemv_t_masked(const device T *mat, const device T *in_vec, device T *out_vec, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;marix_ld, const constant int &amp;batch_ndim, const constant int *batch_shape, const constant int64_t *vector_batch_stride, const constant int64_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant int64_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)</div><div class="ttdoc">Vector matrix multiplication.</div><div class="ttdef"><b>Definition</b> gemv_masked.h:736</div></div>
<div class="ttc" id="akernels_2gemv__masked_8h_html_af890b6ac155165f8ee0c600363938341"><div class="ttname"><a href="kernels_2gemv__masked_8h.html#af890b6ac155165f8ee0c600363938341">gemv_masked</a></div><div class="ttdeci">void gemv_masked(const device T *mat, const device T *in_vec, device T *out_vec, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;marix_ld, const constant int &amp;batch_ndim, const constant int *batch_shape, const constant int64_t *vector_batch_stride, const constant int64_t *matrix_batch_stride, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, const constant int64_t *mask_batch_strides, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)</div><div class="ttdoc">Matrix vector multiplication.</div><div class="ttdef"><b>Definition</b> gemv_masked.h:636</div></div>
<div class="ttc" id="anamespacemetal_html"><div class="ttname"><a href="namespacemetal.html">metal</a></div><div class="ttdef"><b>Definition</b> bf16_math.h:226</div></div>
<div class="ttc" id="anamespacemetal_html_af6e2dd7ae087aba6abac4f0350b7611c"><div class="ttname"><a href="namespacemetal.html#af6e2dd7ae087aba6abac4f0350b7611c">metal::simd_shuffle_down</a></div><div class="ttdeci">METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)</div><div class="ttdef"><b>Definition</b> bf16_math.h:377</div></div>
<div class="ttc" id="astruct___no_mask_html"><div class="ttname"><a href="struct___no_mask.html">_NoMask</a></div><div class="ttdef"><b>Definition</b> gemv_masked.h:10</div></div>
@@ -972,10 +987,13 @@ $(function(){ initResizable(false); });
<div class="ttc" id="astruct_scale_op_html_a02043fac21c68fb8d6863a01f45ede4b"><div class="ttname"><a href="struct_scale_op.html#a02043fac21c68fb8d6863a01f45ede4b">ScaleOp::scale</a></div><div class="ttdeci">OutT scale</div><div class="ttdef"><b>Definition</b> gemv_masked.h:31</div></div>
<div class="ttc" id="astruct_scale_op_html_a69f82bc925843a4e1c14dfe8ad2f3218"><div class="ttname"><a href="struct_scale_op.html#a69f82bc925843a4e1c14dfe8ad2f3218">ScaleOp::apply</a></div><div class="ttdeci">METAL_FUNC OutT apply(InT x) const</div><div class="ttdef"><b>Definition</b> gemv_masked.h:33</div></div>
</div><!-- fragment --></div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>
Generated by&#160;<a href="https://www.doxygen.org/index.html"><img class="footer" src="doxygen.svg" width="104" height="31" alt="doxygen"/></a> 1.12.0
</small></address>
</div><!-- doc-content -->
<!-- start footer part -->
<div id="nav-path" class="navpath"><!-- id is needed for treeview function! -->
<ul>
<li class="navelem"><a class="el" href="dir_938ab0ecf10b8b860ff766c820f665fd.html">mlx</a></li><li class="navelem"><a class="el" href="dir_1d446c9bd3c99228254c9484e0bc5c06.html">backend</a></li><li class="navelem"><a class="el" href="dir_d0c977ea65824390717cdb7efc36c157.html">metal</a></li><li class="navelem"><a class="el" href="dir_70a37effa88bcbd6b791977fa1e64356.html">kernels</a></li><li class="navelem"><a class="el" href="kernels_2gemv__masked_8h.html">gemv_masked.h</a></li>
<li class="footer">Generated by <a href="https://www.doxygen.org/index.html"><img class="footer" src="doxygen.svg" width="104" height="31" alt="doxygen"/></a> 1.13.1 </li>
</ul>
</div>
</body>
</html>