<!-- 2024-10-15 23:12:17 +08:00 -->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "https://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
< html xmlns = "http://www.w3.org/1999/xhtml" lang = "en-US" >
< head >
< meta http-equiv = "Content-Type" content = "text/xhtml;charset=UTF-8" / >
< meta http-equiv = "X-UA-Compatible" content = "IE=11" / >
< meta name = "generator" content = "Doxygen 1.12.0" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1" / >
< title > MLX: mlx/backend/metal/kernels/reduction/reduce_row.h File Reference< / title >
< link href = "tabs.css" rel = "stylesheet" type = "text/css" / >
< script type = "text/javascript" src = "jquery.js" > < / script >
< script type = "text/javascript" src = "dynsections.js" > < / script >
< script type = "text/javascript" src = "clipboard.js" > < / script >
< link href = "navtree.css" rel = "stylesheet" type = "text/css" / >
< script type = "text/javascript" src = "resize.js" > < / script >
< script type = "text/javascript" src = "cookie.js" > < / script >
< link href = "search/search.css" rel = "stylesheet" type = "text/css" / >
< script type = "text/javascript" src = "search/searchdata.js" > < / script >
< script type = "text/javascript" src = "search/search.js" > < / script >
< link href = "doxygen.css" rel = "stylesheet" type = "text/css" / >
< / head >
< body >
< div id = "top" > <!-- do not remove this div, it is closed by doxygen! -->
< div id = "titlearea" >
< table cellspacing = "0" cellpadding = "0" >
< tbody >
< tr id = "projectrow" >
< td id = "projectalign" >
< div id = "projectname" > MLX
< / div >
< / td >
< / tr >
< / tbody >
< / table >
< / div >
<!-- end header part -->
<!-- Generated by Doxygen 1.12.0 -->
< script type = "text/javascript" >
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699& dn=expat.txt MIT */
var searchBox = new SearchBox("searchBox", "search/",'.html');
/* @license-end */
< / script >
< script type = "text/javascript" >
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699& dn=expat.txt MIT */
$(function() { codefold.init(0); });
/* @license-end */
< / script >
< script type = "text/javascript" src = "menudata.js" > < / script >
< script type = "text/javascript" src = "menu.js" > < / script >
< script type = "text/javascript" >
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699& dn=expat.txt MIT */
$(function() {
initMenu('',true,false,'search.php','Search',false);
$(function() { init_search(); });
});
/* @license-end */
< / script >
< div id = "main-nav" > < / div >
< script type = "text/javascript" >
/* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699& dn=expat.txt MIT */
$(function(){ initResizable(false); });
/* @license-end */
< / script >
<!-- window showing the filter options -->
< div id = "MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
onkeydown="return searchBox.OnSearchSelectKey(event)">
< / div >
<!-- iframe showing the search results (closed by default) -->
< div id = "MSearchResultsWindow" >
< div id = "MSearchResults" >
< div class = "SRPage" >
< div id = "SRIndex" >
< div id = "SRResults" > < / div >
< div class = "SRStatus" id = "Loading" > Loading...< / div >
< div class = "SRStatus" id = "Searching" > Searching...< / div >
< div class = "SRStatus" id = "NoMatches" > No Matches< / div >
< / div >
< / div >
< / div >
< / div >
< div id = "nav-path" class = "navpath" >
< ul >
< li class = "navelem" > < a class = "el" href = "dir_938ab0ecf10b8b860ff766c820f665fd.html" > mlx< / a > < / li > < li class = "navelem" > < a class = "el" href = "dir_1d446c9bd3c99228254c9484e0bc5c06.html" > backend< / a > < / li > < li class = "navelem" > < a class = "el" href = "dir_d0c977ea65824390717cdb7efc36c157.html" > metal< / a > < / li > < li class = "navelem" > < a class = "el" href = "dir_70a37effa88bcbd6b791977fa1e64356.html" > kernels< / a > < / li > < li class = "navelem" > < a class = "el" href = "dir_f60cd69d27fd3faa641c79056fff0e2d.html" > reduction< / a > < / li > < / ul >
< / div >
< / div > <!-- top -->
< div id = "doc-content" >
< div class = "header" >
< div class = "summary" >
< a href = "#func-members" > Functions< / a > < / div >
< div class = "headertitle" > < div class = "title" > reduce_row.h File Reference< / div > < / div >
< / div > <!-- header -->
< div class = "contents" >
< p > < a href = "reduce__row_8h_source.html" > Go to the source code of this file.< / a > < / p >
< table class = "memberdecls" >
< tr class = "heading" > < td colspan = "2" > < h2 class = "groupheader" > < a id = "func-members" name = "func-members" > < / a >
Functions< / h2 > < / td > < / tr >
< tr class = "memitem:a9d5e0049a2276f43702fc6907e74a35f" id = "r_a9d5e0049a2276f43702fc6907e74a35f" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / td > < / tr >
< tr class = "memitem:a9d5e0049a2276f43702fc6907e74a35f" > < td class = "memTemplItemLeft" align = "right" valign = "top" > METAL_FUNC void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#a9d5e0049a2276f43702fc6907e74a35f" > per_thread_row_reduce< / a > (thread U totals[N_WRITES], const device T *inputs[N_WRITES], int blocks, int extra, uint lsize_x, uint lid_x)< / td > < / tr >
< tr class = "memdesc:a9d5e0049a2276f43702fc6907e74a35f" > < td class = "mdescLeft" >   < / td > < td class = "mdescRight" > The thread group collaboratively reduces across the rows with bounds checking. < br / > < / td > < / tr >
< tr class = "separator:a9d5e0049a2276f43702fc6907e74a35f" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:a045ec34228e77c79ec67d11c39ff097a" id = "r_a045ec34228e77c79ec67d11c39ff097a" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / td > < / tr >
< tr class = "memitem:a045ec34228e77c79ec67d11c39ff097a" > < td class = "memTemplItemLeft" align = "right" valign = "top" > METAL_FUNC void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#a045ec34228e77c79ec67d11c39ff097a" > per_thread_row_reduce< / a > (thread U totals[N_WRITES], const device T *in, const constant size_t & reduction_size, int blocks, int extra, uint lsize_x, uint lid_x)< / td > < / tr >
< tr class = "memdesc:a045ec34228e77c79ec67d11c39ff097a" > < td class = "mdescLeft" >   < / td > < td class = "mdescRight" > Consecutive rows in a contiguous array. < br / > < / td > < / tr >
< tr class = "separator:a045ec34228e77c79ec67d11c39ff097a" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:a4d00c44e5f4a13be529ff8b664a0a342" id = "r_a4d00c44e5f4a13be529ff8b664a0a342" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / td > < / tr >
< tr class = "memitem:a4d00c44e5f4a13be529ff8b664a0a342" > < td class = "memTemplItemLeft" align = "right" valign = "top" > METAL_FUNC void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#a4d00c44e5f4a13be529ff8b664a0a342" > per_thread_row_reduce< / a > (thread U totals[N_WRITES], const device T *in, const size_t row_idx, int blocks, int extra, const constant int *shape, const constant size_t *strides, const constant int & ndim, uint lsize_x, uint lid_x)< / td > < / tr >
< tr class = "memdesc:a4d00c44e5f4a13be529ff8b664a0a342" > < td class = "mdescLeft" >   < / td > < td class = "mdescRight" > Consecutive rows in an arbitrarily ordered array. < br / > < / td > < / tr >
< tr class = "separator:a4d00c44e5f4a13be529ff8b664a0a342" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:aa146bb611069fd2892f03714fd1cc3cf" id = "r_aa146bb611069fd2892f03714fd1cc3cf" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / td > < / tr >
< tr class = "memitem:aa146bb611069fd2892f03714fd1cc3cf" > < td class = "memTemplItemLeft" align = "right" valign = "top" > METAL_FUNC void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#aa146bb611069fd2892f03714fd1cc3cf" > threadgroup_reduce< / a > (thread U totals[N_WRITES], threadgroup U *shared_vals, uint3 lid, uint simd_lane_id, uint simd_per_group, uint simd_group_id)< / td > < / tr >
< tr class = "memdesc:aa146bb611069fd2892f03714fd1cc3cf" > < td class = "mdescLeft" >   < / td > < td class = "mdescRight" > Reduce within the threadgroup. < br / > < / td > < / tr >
< tr class = "separator:aa146bb611069fd2892f03714fd1cc3cf" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:afd80a25fa84e6cc884dcc8698859ade1" id = "r_afd80a25fa84e6cc884dcc8698859ade1" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> < / td > < / tr >
< tr class = "memitem:afd80a25fa84e6cc884dcc8698859ade1" > < td class = "memTemplItemLeft" align = "right" valign = "top" > METAL_FUNC void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#afd80a25fa84e6cc884dcc8698859ade1" > thread_reduce< / a > (thread U & total, const device T *row, int blocks, int extra)< / td > < / tr >
< tr class = "separator:afd80a25fa84e6cc884dcc8698859ade1" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
<!-- 2024-11-23 04:24:16 +08:00 -->
< tr class = "memitem:aeb49e89f1163cb3093770bb710df9f5e" id = "r_aeb49e89f1163cb3093770bb710df9f5e" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , typename IdxT , int NDIMS, int N_READS = REDUCE_N_READS> < / td > < / tr >
< tr class = "memitem:aeb49e89f1163cb3093770bb710df9f5e" > < td class = "memTemplItemLeft" align = "right" valign = "top" > void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#aeb49e89f1163cb3093770bb710df9f5e" > row_reduce_small< / a > (const device T *in, device U *out, const constant size_t & row_size, const constant size_t & non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int & ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int & reduce_ndim, uint simd_lane_id, uint3 gid, uint3 gsize, uint3 tid, uint3 tsize)< / td > < / tr >
< tr class = "separator:aeb49e89f1163cb3093770bb710df9f5e" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:aef628dfccdb1361da5546f8b17c510bf" id = "r_aef628dfccdb1361da5546f8b17c510bf" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , typename IdxT = size_t, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / td > < / tr >
< tr class = "memitem:aef628dfccdb1361da5546f8b17c510bf" > < td class = "memTemplItemLeft" align = "right" valign = "top" > void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#aef628dfccdb1361da5546f8b17c510bf" > row_reduce_simple< / a > (const device T *in, device U *out, const constant size_t & reduction_size, const constant size_t & out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)< / td > < / tr >
< tr class = "separator:aef628dfccdb1361da5546f8b17c510bf" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
< tr class = "memitem:afba85f5a1c935c124ef52e986d4b2c49" id = "r_afba85f5a1c935c124ef52e986d4b2c49" > < td class = "memTemplParams" colspan = "2" > template< typename T , typename U , typename Op , typename IdxT , int NDIMS, int N_READS = REDUCE_N_READS> < / td > < / tr >
< tr class = "memitem:afba85f5a1c935c124ef52e986d4b2c49" > < td class = "memTemplItemLeft" align = "right" valign = "top" > void  < / td > < td class = "memTemplItemRight" valign = "bottom" > < a class = "el" href = "#afba85f5a1c935c124ef52e986d4b2c49" > row_reduce_looped< / a > (const device T *in, device U *out, const constant size_t & row_size, const constant size_t & non_row_reductions, const constant int *shape, const constant size_t *strides, const constant int & ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int & reduce_ndim, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize, uint simd_lane_id, uint simd_per_group, uint simd_group_id)< / td > < / tr >
< tr class = "separator:afba85f5a1c935c124ef52e986d4b2c49" > < td class = "memSeparator" colspan = "2" >   < / td > < / tr >
<!-- 2024-10-15 23:12:17 +08:00 -->
< / table >
< h2 class = "groupheader" > Function Documentation< / h2 >
< a id = "a045ec34228e77c79ec67d11c39ff097a" name = "a045ec34228e77c79ec67d11c39ff097a" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#a045ec34228e77c79ec67d11c39ff097a" > ◆   < / a > < / span > per_thread_row_reduce() < span class = "overload" > [1/3]< / span > < / h2 >
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / div >
< table class = "memname" >
< tr >
< td class = "memname" > METAL_FUNC void per_thread_row_reduce < / td >
< td > (< / td >
< td class = "paramtype" > thread U< / td > < td class = "paramname" > < span class = "paramname" > < em > totals< / em > < / span > [N_WRITES], < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > in< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > reduction_size< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > blocks< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > extra< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lsize_x< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lid_x< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< p > Consecutive rows in a contiguous array. < / p >
< / div >
< / div >
< a id = "a4d00c44e5f4a13be529ff8b664a0a342" name = "a4d00c44e5f4a13be529ff8b664a0a342" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#a4d00c44e5f4a13be529ff8b664a0a342" > ◆   < / a > < / span > per_thread_row_reduce() < span class = "overload" > [2/3]< / span > < / h2 >
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / div >
< table class = "memname" >
< tr >
< td class = "memname" > METAL_FUNC void per_thread_row_reduce < / td >
< td > (< / td >
< td class = "paramtype" > thread U< / td > < td class = "paramname" > < span class = "paramname" > < em > totals< / em > < / span > [N_WRITES], < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > in< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const size_t< / td > < td class = "paramname" > < span class = "paramname" > < em > row_idx< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > blocks< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > extra< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int *< / td > < td class = "paramname" > < span class = "paramname" > < em > shape< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t *< / td > < td class = "paramname" > < span class = "paramname" > < em > strides< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int & < / td > < td class = "paramname" > < span class = "paramname" > < em > ndim< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lsize_x< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lid_x< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< p > Consecutive rows in an arbitrarily ordered array. < / p >
< / div >
< / div >
< a id = "a9d5e0049a2276f43702fc6907e74a35f" name = "a9d5e0049a2276f43702fc6907e74a35f" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#a9d5e0049a2276f43702fc6907e74a35f" > ◆   < / a > < / span > per_thread_row_reduce() < span class = "overload" > [3/3]< / span > < / h2 >
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / div >
< table class = "memname" >
< tr >
< td class = "memname" > METAL_FUNC void per_thread_row_reduce < / td >
< td > (< / td >
< td class = "paramtype" > thread U< / td > < td class = "paramname" > < span class = "paramname" > < em > totals< / em > < / span > [N_WRITES], < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > inputs< / em > < / span > [N_WRITES], < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > blocks< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > extra< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lsize_x< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > lid_x< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< p > The thread group collaboratively reduces across the rows with bounds checking. < / p >
< p > In the end each thread holds a part of the reduction. < / p >
< / div >
< / div >
<!-- 2024-11-23 04:24:16 +08:00 -->
< a id = "afba85f5a1c935c124ef52e986d4b2c49" name = "afba85f5a1c935c124ef52e986d4b2c49" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#afba85f5a1c935c124ef52e986d4b2c49" > ◆   < / a > < / span > row_reduce_looped()< / h2 >
<!-- 2024-10-15 23:12:17 +08:00 -->
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
<!-- 2024-11-23 04:24:16 +08:00 -->
template< typename T , typename U , typename Op , typename IdxT , int NDIMS, int N_READS = REDUCE_N_READS> < / div >
<!-- 2024-10-15 23:12:17 +08:00 -->
< table class = "memname" >
< tr >
< td class = "memname" > void row_reduce_looped < / td >
< td > (< / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > in< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > device U *< / td > < td class = "paramname" > < span class = "paramname" > < em > out< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > row_size< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > non_row_reductions< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int *< / td > < td class = "paramname" > < span class = "paramname" > < em > shape< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t *< / td > < td class = "paramname" > < span class = "paramname" > < em > strides< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int & < / td > < td class = "paramname" > < span class = "paramname" > < em > ndim< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int *< / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_shape< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t *< / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_strides< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int & < / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_ndim< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gsize< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > lid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > lsize< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_lane_id< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_per_group< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_group_id< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< / div >
< / div >
<!-- 2024-11-23 04:24:16 +08:00 -->
< a id = "aef628dfccdb1361da5546f8b17c510bf" name = "aef628dfccdb1361da5546f8b17c510bf" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#aef628dfccdb1361da5546f8b17c510bf" > ◆   < / a > < / span > row_reduce_simple()< / h2 >
<!-- 2024-10-15 23:12:17 +08:00 -->
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
<!-- 2024-11-23 04:24:16 +08:00 -->
template< typename T , typename U , typename Op , typename IdxT = size_t, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / div >
<!-- 2024-10-15 23:12:17 +08:00 -->
< table class = "memname" >
< tr >
< td class = "memname" > void row_reduce_simple < / td >
< td > (< / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > in< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > device U *< / td > < td class = "paramname" > < span class = "paramname" > < em > out< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > reduction_size< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > out_size< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gsize< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > lid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > lsize< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_lane_id< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_per_group< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_group_id< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< / div >
< / div >
<!-- 2024-11-23 04:24:16 +08:00 -->
< a id = "aeb49e89f1163cb3093770bb710df9f5e" name = "aeb49e89f1163cb3093770bb710df9f5e" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#aeb49e89f1163cb3093770bb710df9f5e" > ◆   < / a > < / span > row_reduce_small()< / h2 >
<!-- 2024-10-15 23:12:17 +08:00 -->
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
<!-- 2024-11-23 04:24:16 +08:00 -->
template< typename T , typename U , typename Op , typename IdxT , int NDIMS, int N_READS = REDUCE_N_READS> < / div >
<!-- 2024-10-15 23:12:17 +08:00 -->
< table class = "memname" >
< tr >
< td class = "memname" > void row_reduce_small < / td >
< td > (< / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > in< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > device U *< / td > < td class = "paramname" > < span class = "paramname" > < em > out< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > row_size< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t & < / td > < td class = "paramname" > < span class = "paramname" > < em > non_row_reductions< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int *< / td > < td class = "paramname" > < span class = "paramname" > < em > shape< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t *< / td > < td class = "paramname" > < span class = "paramname" > < em > strides< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int & < / td > < td class = "paramname" > < span class = "paramname" > < em > ndim< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int *< / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_shape< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant size_t *< / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_strides< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const constant int & < / td > < td class = "paramname" > < span class = "paramname" > < em > reduce_ndim< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_lane_id< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > gsize< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > tid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > tsize< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< / div >
< / div >
< a id = "afd80a25fa84e6cc884dcc8698859ade1" name = "afd80a25fa84e6cc884dcc8698859ade1" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#afd80a25fa84e6cc884dcc8698859ade1" > ◆   < / a > < / span > thread_reduce()< / h2 >
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS> < / div >
< table class = "memname" >
< tr >
< td class = "memname" > METAL_FUNC void thread_reduce < / td >
< td > (< / td >
< td class = "paramtype" > thread U & < / td > < td class = "paramname" > < span class = "paramname" > < em > total< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > const device T *< / td > < td class = "paramname" > < span class = "paramname" > < em > row< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > blocks< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > int< / td > < td class = "paramname" > < span class = "paramname" > < em > extra< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< / div >
< / div >
< a id = "aa146bb611069fd2892f03714fd1cc3cf" name = "aa146bb611069fd2892f03714fd1cc3cf" > < / a >
< h2 class = "memtitle" > < span class = "permalink" > < a href = "#aa146bb611069fd2892f03714fd1cc3cf" > ◆   < / a > < / span > threadgroup_reduce()< / h2 >
< div class = "memitem" >
< div class = "memproto" >
< div class = "memtemplate" >
template< typename T , typename U , typename Op , int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> < / div >
< table class = "memname" >
< tr >
< td class = "memname" > METAL_FUNC void threadgroup_reduce < / td >
< td > (< / td >
< td class = "paramtype" > thread U< / td > < td class = "paramname" > < span class = "paramname" > < em > totals< / em > < / span > [N_WRITES], < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > threadgroup U *< / td > < td class = "paramname" > < span class = "paramname" > < em > shared_vals< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint3< / td > < td class = "paramname" > < span class = "paramname" > < em > lid< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_lane_id< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_per_group< / em > < / span > , < / td >
< / tr >
< tr >
< td class = "paramkey" > < / td >
< td > < / td >
< td class = "paramtype" > uint< / td > < td class = "paramname" > < span class = "paramname" > < em > simd_group_id< / em > < / span >   )< / td >
< / tr >
< / table >
< / div > < div class = "memdoc" >
< p > Reduce within the threadgroup. < / p >
< / div >
< / div >
< / div > <!-- contents -->
<!-- start footer part -->
< hr class = "footer" / > < address class = "footer" > < small >
Generated by  < a href = "https://www.doxygen.org/index.html" > < img class = "footer" src = "doxygen.svg" width = "104" height = "31" alt = "doxygen" / > < / a > 1.12.0
< / small > < / address >
< / div > <!-- doc - content -->
< / body >
< / html >