mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
docs update
This commit is contained in:

committed by
CircleCI Docs

parent
b95224115c
commit
6e9288a41c
308
docs/build/html/sort_8h_source.html
vendored
308
docs/build/html/sort_8h_source.html
vendored
@@ -635,14 +635,14 @@ $(function() { codefold.init(0); });
|
||||
<div class="line"><a id="l00522" name="l00522"></a><span class="lineno"> 522</span> <span class="keywordtype">bool</span> ARG_SORT,</div>
|
||||
<div class="line"><a id="l00523" name="l00523"></a><span class="lineno"> 523</span> <span class="keywordtype">short</span> BLOCK_THREADS,</div>
|
||||
<div class="line"><a id="l00524" name="l00524"></a><span class="lineno"> 524</span> <span class="keywordtype">short</span> N_PER_THREAD></div>
|
||||
<div class="line"><a id="l00525" name="l00525"></a><span class="lineno"> 525</span>[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] <span class="keywordtype">void</span></div>
|
||||
<div class="foldopen" id="foldopen00526" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00526" name="l00526"></a><span class="lineno"><a class="line" href="sort_8h.html#a50ae11454e4dfa374a9bd256cdbbf605"> 526</a></span><a class="code hl_function" href="sort_8h.html#a50ae11454e4dfa374a9bd256cdbbf605">mb_block_partition</a>(</div>
|
||||
<div class="line"><a id="l00527" name="l00527"></a><span class="lineno"> 527</span> device idx_t* block_partitions [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00528" name="l00528"></a><span class="lineno"> 528</span> <span class="keyword">const</span> device val_t* dev_vals [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00529" name="l00529"></a><span class="lineno"> 529</span> <span class="keyword">const</span> device idx_t* dev_idxs [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00530" name="l00530"></a><span class="lineno"> 530</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& size_sorted_axis [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00531" name="l00531"></a><span class="lineno"> 531</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& merge_tiles [[buffer(4)]],</div>
|
||||
<div class="foldopen" id="foldopen00525" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00525" name="l00525"></a><span class="lineno"><a class="line" href="sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2"> 525</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2">mb_block_partition</a>(</div>
|
||||
<div class="line"><a id="l00526" name="l00526"></a><span class="lineno"> 526</span> device idx_t* block_partitions [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00527" name="l00527"></a><span class="lineno"> 527</span> <span class="keyword">const</span> device val_t* dev_vals [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00528" name="l00528"></a><span class="lineno"> 528</span> <span class="keyword">const</span> device idx_t* dev_idxs [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00529" name="l00529"></a><span class="lineno"> 529</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& size_sorted_axis [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00530" name="l00530"></a><span class="lineno"> 530</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& merge_tiles [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00531" name="l00531"></a><span class="lineno"> 531</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& n_blocks [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00532" name="l00532"></a><span class="lineno"> 532</span> uint3 tid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00533" name="l00533"></a><span class="lineno"> 533</span> uint3 lid [[thread_position_in_threadgroup]],</div>
|
||||
<div class="line"><a id="l00534" name="l00534"></a><span class="lineno"> 534</span> uint3 tgp_dims [[threads_per_threadgroup]]) {</div>
|
||||
@@ -657,152 +657,158 @@ $(function() { codefold.init(0); });
|
||||
<div class="line"><a id="l00543" name="l00543"></a><span class="lineno"> 543</span> dev_vals += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00544" name="l00544"></a><span class="lineno"> 544</span> dev_idxs += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00545" name="l00545"></a><span class="lineno"> 545</span> </div>
|
||||
<div class="line"><a id="l00546" name="l00546"></a><span class="lineno"> 546</span> <span class="comment">// Find location in merge step</span></div>
|
||||
<div class="line"><a id="l00547" name="l00547"></a><span class="lineno"> 547</span> <span class="keywordtype">int</span> merge_group = lid.x / merge_tiles;</div>
|
||||
<div class="line"><a id="l00548" name="l00548"></a><span class="lineno"> 548</span> <span class="keywordtype">int</span> merge_lane = lid.x % merge_tiles;</div>
|
||||
<div class="line"><a id="l00549" name="l00549"></a><span class="lineno"> 549</span> </div>
|
||||
<div class="line"><a id="l00550" name="l00550"></a><span class="lineno"> 550</span> <span class="keywordtype">int</span> sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;</div>
|
||||
<div class="line"><a id="l00551" name="l00551"></a><span class="lineno"> 551</span> <span class="keywordtype">int</span> sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;</div>
|
||||
<div class="line"><a id="l00552" name="l00552"></a><span class="lineno"> 552</span> </div>
|
||||
<div class="line"><a id="l00553" name="l00553"></a><span class="lineno"> 553</span> <span class="keywordtype">int</span> A_st = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st);</div>
|
||||
<div class="line"><a id="l00554" name="l00554"></a><span class="lineno"> 554</span> <span class="keywordtype">int</span> A_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00555" name="l00555"></a><span class="lineno"> 555</span> <span class="keywordtype">int</span> B_st = A_ed;</div>
|
||||
<div class="line"><a id="l00556" name="l00556"></a><span class="lineno"> 556</span> <span class="keywordtype">int</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, B_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00557" name="l00557"></a><span class="lineno"> 557</span> </div>
|
||||
<div class="line"><a id="l00558" name="l00558"></a><span class="lineno"> 558</span> <span class="keywordtype">int</span> partition_at = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);</div>
|
||||
<div class="line"><a id="l00559" name="l00559"></a><span class="lineno"> 559</span> <span class="keywordtype">int</span> partition = sort_kernel::merge_partition(</div>
|
||||
<div class="line"><a id="l00560" name="l00560"></a><span class="lineno"> 560</span> dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);</div>
|
||||
<div class="line"><a id="l00561" name="l00561"></a><span class="lineno"> 561</span> </div>
|
||||
<div class="line"><a id="l00562" name="l00562"></a><span class="lineno"> 562</span> block_partitions[lid.x] = A_st + partition;</div>
|
||||
<div class="line"><a id="l00563" name="l00563"></a><span class="lineno"> 563</span>}</div>
|
||||
<div class="line"><a id="l00546" name="l00546"></a><span class="lineno"> 546</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = lid.x; i <= n_blocks; i += tgp_dims.x) {</div>
|
||||
<div class="line"><a id="l00547" name="l00547"></a><span class="lineno"> 547</span> <span class="comment">// Find location in merge step</span></div>
|
||||
<div class="line"><a id="l00548" name="l00548"></a><span class="lineno"> 548</span> <span class="keywordtype">int</span> merge_group = i / merge_tiles;</div>
|
||||
<div class="line"><a id="l00549" name="l00549"></a><span class="lineno"> 549</span> <span class="keywordtype">int</span> merge_lane = i % merge_tiles;</div>
|
||||
<div class="line"><a id="l00550" name="l00550"></a><span class="lineno"> 550</span> </div>
|
||||
<div class="line"><a id="l00551" name="l00551"></a><span class="lineno"> 551</span> <span class="keywordtype">int</span> sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;</div>
|
||||
<div class="line"><a id="l00552" name="l00552"></a><span class="lineno"> 552</span> <span class="keywordtype">int</span> sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;</div>
|
||||
<div class="line"><a id="l00553" name="l00553"></a><span class="lineno"> 553</span> </div>
|
||||
<div class="line"><a id="l00554" name="l00554"></a><span class="lineno"> 554</span> <span class="keywordtype">int</span> A_st = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st);</div>
|
||||
<div class="line"><a id="l00555" name="l00555"></a><span class="lineno"> 555</span> <span class="keywordtype">int</span> A_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00556" name="l00556"></a><span class="lineno"> 556</span> <span class="keywordtype">int</span> B_st = A_ed;</div>
|
||||
<div class="line"><a id="l00557" name="l00557"></a><span class="lineno"> 557</span> <span class="keywordtype">int</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, B_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00558" name="l00558"></a><span class="lineno"> 558</span> </div>
|
||||
<div class="line"><a id="l00559" name="l00559"></a><span class="lineno"> 559</span> <span class="keywordtype">int</span> partition_at = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);</div>
|
||||
<div class="line"><a id="l00560" name="l00560"></a><span class="lineno"> 560</span> <span class="keywordtype">int</span> partition = sort_kernel::merge_partition(</div>
|
||||
<div class="line"><a id="l00561" name="l00561"></a><span class="lineno"> 561</span> dev_vals + A_st,</div>
|
||||
<div class="line"><a id="l00562" name="l00562"></a><span class="lineno"> 562</span> dev_vals + B_st,</div>
|
||||
<div class="line"><a id="l00563" name="l00563"></a><span class="lineno"> 563</span> A_ed - A_st,</div>
|
||||
<div class="line"><a id="l00564" name="l00564"></a><span class="lineno"> 564</span> B_ed - B_st,</div>
|
||||
<div class="line"><a id="l00565" name="l00565"></a><span class="lineno"> 565</span> partition_at);</div>
|
||||
<div class="line"><a id="l00566" name="l00566"></a><span class="lineno"> 566</span> </div>
|
||||
<div class="line"><a id="l00567" name="l00567"></a><span class="lineno"> 567</span> block_partitions[i] = A_st + partition;</div>
|
||||
<div class="line"><a id="l00568" name="l00568"></a><span class="lineno"> 568</span> }</div>
|
||||
<div class="line"><a id="l00569" name="l00569"></a><span class="lineno"> 569</span>}</div>
|
||||
</div>
|
||||
<div class="line"><a id="l00564" name="l00564"></a><span class="lineno"> 564</span> </div>
|
||||
<div class="line"><a id="l00565" name="l00565"></a><span class="lineno"> 565</span><span class="keyword">template</span> <</div>
|
||||
<div class="line"><a id="l00566" name="l00566"></a><span class="lineno"> 566</span> <span class="keyword">typename</span> val_t,</div>
|
||||
<div class="line"><a id="l00567" name="l00567"></a><span class="lineno"> 567</span> <span class="keyword">typename</span> idx_t,</div>
|
||||
<div class="line"><a id="l00568" name="l00568"></a><span class="lineno"> 568</span> <span class="keywordtype">bool</span> ARG_SORT,</div>
|
||||
<div class="line"><a id="l00569" name="l00569"></a><span class="lineno"> 569</span> <span class="keywordtype">short</span> BLOCK_THREADS,</div>
|
||||
<div class="line"><a id="l00570" name="l00570"></a><span class="lineno"> 570</span> <span class="keywordtype">short</span> N_PER_THREAD,</div>
|
||||
<div class="line"><a id="l00571" name="l00571"></a><span class="lineno"> 571</span> <span class="keyword">typename</span> CompareOp = <a class="code hl_struct" href="struct_less_than.html">LessThan<val_t></a>></div>
|
||||
<div class="line"><a id="l00572" name="l00572"></a><span class="lineno"> 572</span>[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] <span class="keywordtype">void</span></div>
|
||||
<div class="foldopen" id="foldopen00573" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00573" name="l00573"></a><span class="lineno"><a class="line" href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807"> 573</a></span><a class="code hl_function" href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807">mb_block_merge</a>(</div>
|
||||
<div class="line"><a id="l00574" name="l00574"></a><span class="lineno"> 574</span> <span class="keyword">const</span> device idx_t* block_partitions [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00575" name="l00575"></a><span class="lineno"> 575</span> <span class="keyword">const</span> device val_t* dev_vals_in [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00576" name="l00576"></a><span class="lineno"> 576</span> <span class="keyword">const</span> device idx_t* dev_idxs_in [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00577" name="l00577"></a><span class="lineno"> 577</span> device val_t* dev_vals_out [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00578" name="l00578"></a><span class="lineno"> 578</span> device idx_t* dev_idxs_out [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00579" name="l00579"></a><span class="lineno"> 579</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& size_sorted_axis [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00580" name="l00580"></a><span class="lineno"> 580</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& merge_tiles [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00581" name="l00581"></a><span class="lineno"> 581</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& num_tiles [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00582" name="l00582"></a><span class="lineno"> 582</span> uint3 tid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00583" name="l00583"></a><span class="lineno"> 583</span> uint3 lid [[thread_position_in_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00584" name="l00584"></a><span class="lineno"> 584</span> <span class="keyword">using </span>sort_kernel = <a class="code hl_struct" href="struct_kernel_multi_block_merge_sort.html">KernelMultiBlockMergeSort</a><</div>
|
||||
<div class="line"><a id="l00585" name="l00585"></a><span class="lineno"> 585</span> val_t,</div>
|
||||
<div class="line"><a id="l00586" name="l00586"></a><span class="lineno"> 586</span> idx_t,</div>
|
||||
<div class="line"><a id="l00587" name="l00587"></a><span class="lineno"> 587</span> ARG_SORT,</div>
|
||||
<div class="line"><a id="l00588" name="l00588"></a><span class="lineno"> 588</span> BLOCK_THREADS,</div>
|
||||
<div class="line"><a id="l00589" name="l00589"></a><span class="lineno"> 589</span> N_PER_THREAD,</div>
|
||||
<div class="line"><a id="l00590" name="l00590"></a><span class="lineno"> 590</span> CompareOp>;</div>
|
||||
<div class="line"><a id="l00591" name="l00591"></a><span class="lineno"> 591</span> </div>
|
||||
<div class="line"><a id="l00592" name="l00592"></a><span class="lineno"> 592</span> <span class="keyword">using </span>block_sort_t = <span class="keyword">typename</span> sort_kernel::block_merge_sort_t;</div>
|
||||
<div class="line"><a id="l00593" name="l00593"></a><span class="lineno"> 593</span> </div>
|
||||
<div class="line"><a id="l00594" name="l00594"></a><span class="lineno"> 594</span> block_partitions += tid.y * (num_tiles + 1);</div>
|
||||
<div class="line"><a id="l00595" name="l00595"></a><span class="lineno"> 595</span> dev_vals_in += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00596" name="l00596"></a><span class="lineno"> 596</span> dev_idxs_in += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00597" name="l00597"></a><span class="lineno"> 597</span> dev_vals_out += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00598" name="l00598"></a><span class="lineno"> 598</span> dev_idxs_out += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00570" name="l00570"></a><span class="lineno"> 570</span> </div>
|
||||
<div class="line"><a id="l00571" name="l00571"></a><span class="lineno"> 571</span><span class="keyword">template</span> <</div>
|
||||
<div class="line"><a id="l00572" name="l00572"></a><span class="lineno"> 572</span> <span class="keyword">typename</span> val_t,</div>
|
||||
<div class="line"><a id="l00573" name="l00573"></a><span class="lineno"> 573</span> <span class="keyword">typename</span> idx_t,</div>
|
||||
<div class="line"><a id="l00574" name="l00574"></a><span class="lineno"> 574</span> <span class="keywordtype">bool</span> ARG_SORT,</div>
|
||||
<div class="line"><a id="l00575" name="l00575"></a><span class="lineno"> 575</span> <span class="keywordtype">short</span> BLOCK_THREADS,</div>
|
||||
<div class="line"><a id="l00576" name="l00576"></a><span class="lineno"> 576</span> <span class="keywordtype">short</span> N_PER_THREAD,</div>
|
||||
<div class="line"><a id="l00577" name="l00577"></a><span class="lineno"> 577</span> <span class="keyword">typename</span> CompareOp = <a class="code hl_struct" href="struct_less_than.html">LessThan<val_t></a>></div>
|
||||
<div class="line"><a id="l00578" name="l00578"></a><span class="lineno"> 578</span>[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] <span class="keywordtype">void</span></div>
|
||||
<div class="foldopen" id="foldopen00579" data-start="{" data-end="}">
|
||||
<div class="line"><a id="l00579" name="l00579"></a><span class="lineno"><a class="line" href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807"> 579</a></span><a class="code hl_function" href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807">mb_block_merge</a>(</div>
|
||||
<div class="line"><a id="l00580" name="l00580"></a><span class="lineno"> 580</span> <span class="keyword">const</span> device idx_t* block_partitions [[buffer(0)]],</div>
|
||||
<div class="line"><a id="l00581" name="l00581"></a><span class="lineno"> 581</span> <span class="keyword">const</span> device val_t* dev_vals_in [[buffer(1)]],</div>
|
||||
<div class="line"><a id="l00582" name="l00582"></a><span class="lineno"> 582</span> <span class="keyword">const</span> device idx_t* dev_idxs_in [[buffer(2)]],</div>
|
||||
<div class="line"><a id="l00583" name="l00583"></a><span class="lineno"> 583</span> device val_t* dev_vals_out [[buffer(3)]],</div>
|
||||
<div class="line"><a id="l00584" name="l00584"></a><span class="lineno"> 584</span> device idx_t* dev_idxs_out [[buffer(4)]],</div>
|
||||
<div class="line"><a id="l00585" name="l00585"></a><span class="lineno"> 585</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& size_sorted_axis [[buffer(5)]],</div>
|
||||
<div class="line"><a id="l00586" name="l00586"></a><span class="lineno"> 586</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& merge_tiles [[buffer(6)]],</div>
|
||||
<div class="line"><a id="l00587" name="l00587"></a><span class="lineno"> 587</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>& num_tiles [[buffer(7)]],</div>
|
||||
<div class="line"><a id="l00588" name="l00588"></a><span class="lineno"> 588</span> uint3 tid [[threadgroup_position_in_grid]],</div>
|
||||
<div class="line"><a id="l00589" name="l00589"></a><span class="lineno"> 589</span> uint3 lid [[thread_position_in_threadgroup]]) {</div>
|
||||
<div class="line"><a id="l00590" name="l00590"></a><span class="lineno"> 590</span> <span class="keyword">using </span>sort_kernel = <a class="code hl_struct" href="struct_kernel_multi_block_merge_sort.html">KernelMultiBlockMergeSort</a><</div>
|
||||
<div class="line"><a id="l00591" name="l00591"></a><span class="lineno"> 591</span> val_t,</div>
|
||||
<div class="line"><a id="l00592" name="l00592"></a><span class="lineno"> 592</span> idx_t,</div>
|
||||
<div class="line"><a id="l00593" name="l00593"></a><span class="lineno"> 593</span> ARG_SORT,</div>
|
||||
<div class="line"><a id="l00594" name="l00594"></a><span class="lineno"> 594</span> BLOCK_THREADS,</div>
|
||||
<div class="line"><a id="l00595" name="l00595"></a><span class="lineno"> 595</span> N_PER_THREAD,</div>
|
||||
<div class="line"><a id="l00596" name="l00596"></a><span class="lineno"> 596</span> CompareOp>;</div>
|
||||
<div class="line"><a id="l00597" name="l00597"></a><span class="lineno"> 597</span> </div>
|
||||
<div class="line"><a id="l00598" name="l00598"></a><span class="lineno"> 598</span> <span class="keyword">using </span>block_sort_t = <span class="keyword">typename</span> sort_kernel::block_merge_sort_t;</div>
|
||||
<div class="line"><a id="l00599" name="l00599"></a><span class="lineno"> 599</span> </div>
|
||||
<div class="line"><a id="l00600" name="l00600"></a><span class="lineno"> 600</span> <span class="keywordtype">int</span> block_idx = tid.x;</div>
|
||||
<div class="line"><a id="l00601" name="l00601"></a><span class="lineno"> 601</span> <span class="keywordtype">int</span> merge_group = block_idx / merge_tiles;</div>
|
||||
<div class="line"><a id="l00602" name="l00602"></a><span class="lineno"> 602</span> <span class="keywordtype">int</span> sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;</div>
|
||||
<div class="line"><a id="l00603" name="l00603"></a><span class="lineno"> 603</span> <span class="keywordtype">int</span> sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;</div>
|
||||
<div class="line"><a id="l00604" name="l00604"></a><span class="lineno"> 604</span> <span class="keywordtype">int</span> sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;</div>
|
||||
<div class="line"><a id="l00600" name="l00600"></a><span class="lineno"> 600</span> block_partitions += tid.y * (num_tiles + 1);</div>
|
||||
<div class="line"><a id="l00601" name="l00601"></a><span class="lineno"> 601</span> dev_vals_in += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00602" name="l00602"></a><span class="lineno"> 602</span> dev_idxs_in += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00603" name="l00603"></a><span class="lineno"> 603</span> dev_vals_out += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00604" name="l00604"></a><span class="lineno"> 604</span> dev_idxs_out += tid.y * size_sorted_axis;</div>
|
||||
<div class="line"><a id="l00605" name="l00605"></a><span class="lineno"> 605</span> </div>
|
||||
<div class="line"><a id="l00606" name="l00606"></a><span class="lineno"> 606</span> <span class="keywordtype">int</span> A_st = block_partitions[block_idx + 0];</div>
|
||||
<div class="line"><a id="l00607" name="l00607"></a><span class="lineno"> 607</span> <span class="keywordtype">int</span> A_ed = block_partitions[block_idx + 1];</div>
|
||||
<div class="line"><a id="l00608" name="l00608"></a><span class="lineno"> 608</span> <span class="keywordtype">int</span> B_st = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);</div>
|
||||
<div class="line"><a id="l00609" name="l00609"></a><span class="lineno"> 609</span> <span class="keywordtype">int</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(</div>
|
||||
<div class="line"><a id="l00610" name="l00610"></a><span class="lineno"> 610</span> size_sorted_axis,</div>
|
||||
<div class="line"><a id="l00611" name="l00611"></a><span class="lineno"> 611</span> 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);</div>
|
||||
<div class="line"><a id="l00612" name="l00612"></a><span class="lineno"> 612</span> </div>
|
||||
<div class="line"><a id="l00613" name="l00613"></a><span class="lineno"> 613</span> <span class="keywordflow">if</span> ((block_idx % merge_tiles) == merge_tiles - 1) {</div>
|
||||
<div class="line"><a id="l00614" name="l00614"></a><span class="lineno"> 614</span> A_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00615" name="l00615"></a><span class="lineno"> 615</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz);</div>
|
||||
<div class="line"><a id="l00616" name="l00616"></a><span class="lineno"> 616</span> }</div>
|
||||
<div class="line"><a id="l00617" name="l00617"></a><span class="lineno"> 617</span> </div>
|
||||
<div class="line"><a id="l00618" name="l00618"></a><span class="lineno"> 618</span> <span class="keywordtype">int</span> A_sz = A_ed - A_st;</div>
|
||||
<div class="line"><a id="l00619" name="l00619"></a><span class="lineno"> 619</span> <span class="keywordtype">int</span> B_sz = B_ed - B_st;</div>
|
||||
<div class="line"><a id="l00620" name="l00620"></a><span class="lineno"> 620</span> </div>
|
||||
<div class="line"><a id="l00621" name="l00621"></a><span class="lineno"> 621</span> <span class="comment">// Load from global memory</span></div>
|
||||
<div class="line"><a id="l00622" name="l00622"></a><span class="lineno"> 622</span> thread val_t thread_vals[N_PER_THREAD];</div>
|
||||
<div class="line"><a id="l00623" name="l00623"></a><span class="lineno"> 623</span> thread idx_t thread_idxs[N_PER_THREAD];</div>
|
||||
<div class="line"><a id="l00624" name="l00624"></a><span class="lineno"> 624</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; i++) {</div>
|
||||
<div class="line"><a id="l00625" name="l00625"></a><span class="lineno"> 625</span> <span class="keywordtype">int</span> idx = BLOCK_THREADS * i + lid.x;</div>
|
||||
<div class="line"><a id="l00626" name="l00626"></a><span class="lineno"> 626</span> <span class="keywordflow">if</span> (idx < (A_sz + B_sz)) {</div>
|
||||
<div class="line"><a id="l00627" name="l00627"></a><span class="lineno"> 627</span> thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]</div>
|
||||
<div class="line"><a id="l00628" name="l00628"></a><span class="lineno"> 628</span> : dev_vals_in[B_st + idx - A_sz];</div>
|
||||
<div class="line"><a id="l00629" name="l00629"></a><span class="lineno"> 629</span> thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]</div>
|
||||
<div class="line"><a id="l00630" name="l00630"></a><span class="lineno"> 630</span> : dev_idxs_in[B_st + idx - A_sz];</div>
|
||||
<div class="line"><a id="l00631" name="l00631"></a><span class="lineno"> 631</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00632" name="l00632"></a><span class="lineno"> 632</span> thread_vals[i] = CompareOp::init;</div>
|
||||
<div class="line"><a id="l00633" name="l00633"></a><span class="lineno"> 633</span> thread_idxs[i] = 0;</div>
|
||||
<div class="line"><a id="l00634" name="l00634"></a><span class="lineno"> 634</span> }</div>
|
||||
<div class="line"><a id="l00635" name="l00635"></a><span class="lineno"> 635</span> }</div>
|
||||
<div class="line"><a id="l00636" name="l00636"></a><span class="lineno"> 636</span> </div>
|
||||
<div class="line"><a id="l00637" name="l00637"></a><span class="lineno"> 637</span> <span class="comment">// Write to shared memory</span></div>
|
||||
<div class="line"><a id="l00638" name="l00638"></a><span class="lineno"> 638</span> threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];</div>
|
||||
<div class="line"><a id="l00639" name="l00639"></a><span class="lineno"> 639</span> threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];</div>
|
||||
<div class="line"><a id="l00640" name="l00640"></a><span class="lineno"> 640</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00641" name="l00641"></a><span class="lineno"> 641</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; i++) {</div>
|
||||
<div class="line"><a id="l00642" name="l00642"></a><span class="lineno"> 642</span> <span class="keywordtype">int</span> idx = BLOCK_THREADS * i + lid.x;</div>
|
||||
<div class="line"><a id="l00643" name="l00643"></a><span class="lineno"> 643</span> tgp_vals[idx] = thread_vals[i];</div>
|
||||
<div class="line"><a id="l00644" name="l00644"></a><span class="lineno"> 644</span> tgp_idxs[idx] = thread_idxs[i];</div>
|
||||
<div class="line"><a id="l00645" name="l00645"></a><span class="lineno"> 645</span> }</div>
|
||||
<div class="line"><a id="l00606" name="l00606"></a><span class="lineno"> 606</span> <span class="keywordtype">int</span> block_idx = tid.x;</div>
|
||||
<div class="line"><a id="l00607" name="l00607"></a><span class="lineno"> 607</span> <span class="keywordtype">int</span> merge_group = block_idx / merge_tiles;</div>
|
||||
<div class="line"><a id="l00608" name="l00608"></a><span class="lineno"> 608</span> <span class="keywordtype">int</span> sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;</div>
|
||||
<div class="line"><a id="l00609" name="l00609"></a><span class="lineno"> 609</span> <span class="keywordtype">int</span> sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;</div>
|
||||
<div class="line"><a id="l00610" name="l00610"></a><span class="lineno"> 610</span> <span class="keywordtype">int</span> sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;</div>
|
||||
<div class="line"><a id="l00611" name="l00611"></a><span class="lineno"> 611</span> </div>
|
||||
<div class="line"><a id="l00612" name="l00612"></a><span class="lineno"> 612</span> <span class="keywordtype">int</span> A_st = block_partitions[block_idx + 0];</div>
|
||||
<div class="line"><a id="l00613" name="l00613"></a><span class="lineno"> 613</span> <span class="keywordtype">int</span> A_ed = block_partitions[block_idx + 1];</div>
|
||||
<div class="line"><a id="l00614" name="l00614"></a><span class="lineno"> 614</span> <span class="keywordtype">int</span> B_st = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);</div>
|
||||
<div class="line"><a id="l00615" name="l00615"></a><span class="lineno"> 615</span> <span class="keywordtype">int</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(</div>
|
||||
<div class="line"><a id="l00616" name="l00616"></a><span class="lineno"> 616</span> size_sorted_axis,</div>
|
||||
<div class="line"><a id="l00617" name="l00617"></a><span class="lineno"> 617</span> 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);</div>
|
||||
<div class="line"><a id="l00618" name="l00618"></a><span class="lineno"> 618</span> </div>
|
||||
<div class="line"><a id="l00619" name="l00619"></a><span class="lineno"> 619</span> <span class="keywordflow">if</span> ((block_idx % merge_tiles) == merge_tiles - 1) {</div>
|
||||
<div class="line"><a id="l00620" name="l00620"></a><span class="lineno"> 620</span> A_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz / 2);</div>
|
||||
<div class="line"><a id="l00621" name="l00621"></a><span class="lineno"> 621</span> B_ed = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(size_sorted_axis, sort_st + sort_sz);</div>
|
||||
<div class="line"><a id="l00622" name="l00622"></a><span class="lineno"> 622</span> }</div>
|
||||
<div class="line"><a id="l00623" name="l00623"></a><span class="lineno"> 623</span> </div>
|
||||
<div class="line"><a id="l00624" name="l00624"></a><span class="lineno"> 624</span> <span class="keywordtype">int</span> A_sz = A_ed - A_st;</div>
|
||||
<div class="line"><a id="l00625" name="l00625"></a><span class="lineno"> 625</span> <span class="keywordtype">int</span> B_sz = B_ed - B_st;</div>
|
||||
<div class="line"><a id="l00626" name="l00626"></a><span class="lineno"> 626</span> </div>
|
||||
<div class="line"><a id="l00627" name="l00627"></a><span class="lineno"> 627</span> <span class="comment">// Load from global memory</span></div>
|
||||
<div class="line"><a id="l00628" name="l00628"></a><span class="lineno"> 628</span> thread val_t thread_vals[N_PER_THREAD];</div>
|
||||
<div class="line"><a id="l00629" name="l00629"></a><span class="lineno"> 629</span> thread idx_t thread_idxs[N_PER_THREAD];</div>
|
||||
<div class="line"><a id="l00630" name="l00630"></a><span class="lineno"> 630</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; i++) {</div>
|
||||
<div class="line"><a id="l00631" name="l00631"></a><span class="lineno"> 631</span> <span class="keywordtype">int</span> idx = BLOCK_THREADS * i + lid.x;</div>
|
||||
<div class="line"><a id="l00632" name="l00632"></a><span class="lineno"> 632</span> <span class="keywordflow">if</span> (idx < (A_sz + B_sz)) {</div>
|
||||
<div class="line"><a id="l00633" name="l00633"></a><span class="lineno"> 633</span> thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]</div>
|
||||
<div class="line"><a id="l00634" name="l00634"></a><span class="lineno"> 634</span> : dev_vals_in[B_st + idx - A_sz];</div>
|
||||
<div class="line"><a id="l00635" name="l00635"></a><span class="lineno"> 635</span> thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]</div>
|
||||
<div class="line"><a id="l00636" name="l00636"></a><span class="lineno"> 636</span> : dev_idxs_in[B_st + idx - A_sz];</div>
|
||||
<div class="line"><a id="l00637" name="l00637"></a><span class="lineno"> 637</span> } <span class="keywordflow">else</span> {</div>
|
||||
<div class="line"><a id="l00638" name="l00638"></a><span class="lineno"> 638</span> thread_vals[i] = CompareOp::init;</div>
|
||||
<div class="line"><a id="l00639" name="l00639"></a><span class="lineno"> 639</span> thread_idxs[i] = 0;</div>
|
||||
<div class="line"><a id="l00640" name="l00640"></a><span class="lineno"> 640</span> }</div>
|
||||
<div class="line"><a id="l00641" name="l00641"></a><span class="lineno"> 641</span> }</div>
|
||||
<div class="line"><a id="l00642" name="l00642"></a><span class="lineno"> 642</span> </div>
|
||||
<div class="line"><a id="l00643" name="l00643"></a><span class="lineno"> 643</span> <span class="comment">// Write to shared memory</span></div>
|
||||
<div class="line"><a id="l00644" name="l00644"></a><span class="lineno"> 644</span> threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];</div>
|
||||
<div class="line"><a id="l00645" name="l00645"></a><span class="lineno"> 645</span> threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];</div>
|
||||
<div class="line"><a id="l00646" name="l00646"></a><span class="lineno"> 646</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00647" name="l00647"></a><span class="lineno"> 647</span> </div>
|
||||
<div class="line"><a id="l00648" name="l00648"></a><span class="lineno"> 648</span> <span class="comment">// Merge</span></div>
|
||||
<div class="line"><a id="l00649" name="l00649"></a><span class="lineno"> 649</span> <span class="keywordtype">int</span> sort_md_local = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(A_sz + B_sz, N_PER_THREAD * <span class="keywordtype">int</span>(lid.x));</div>
|
||||
<div class="line"><a id="l00650" name="l00650"></a><span class="lineno"> 650</span> </div>
|
||||
<div class="line"><a id="l00651" name="l00651"></a><span class="lineno"> 651</span> <span class="keywordtype">int</span> A_st_local = block_sort_t::merge_partition(</div>
|
||||
<div class="line"><a id="l00652" name="l00652"></a><span class="lineno"> 652</span> tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);</div>
|
||||
<div class="line"><a id="l00653" name="l00653"></a><span class="lineno"> 653</span> <span class="keywordtype">int</span> A_ed_local = A_sz;</div>
|
||||
<div class="line"><a id="l00654" name="l00654"></a><span class="lineno"> 654</span> </div>
|
||||
<div class="line"><a id="l00655" name="l00655"></a><span class="lineno"> 655</span> <span class="keywordtype">int</span> B_st_local = sort_md_local - A_st_local;</div>
|
||||
<div class="line"><a id="l00656" name="l00656"></a><span class="lineno"> 656</span> <span class="keywordtype">int</span> B_ed_local = B_sz;</div>
|
||||
<div class="line"><a id="l00657" name="l00657"></a><span class="lineno"> 657</span> </div>
|
||||
<div class="line"><a id="l00658" name="l00658"></a><span class="lineno"> 658</span> <span class="keywordtype">int</span> A_sz_local = A_ed_local - A_st_local;</div>
|
||||
<div class="line"><a id="l00659" name="l00659"></a><span class="lineno"> 659</span> <span class="keywordtype">int</span> B_sz_local = B_ed_local - B_st_local;</div>
|
||||
<div class="line"><a id="l00647" name="l00647"></a><span class="lineno"> 647</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; i++) {</div>
|
||||
<div class="line"><a id="l00648" name="l00648"></a><span class="lineno"> 648</span> <span class="keywordtype">int</span> idx = BLOCK_THREADS * i + lid.x;</div>
|
||||
<div class="line"><a id="l00649" name="l00649"></a><span class="lineno"> 649</span> tgp_vals[idx] = thread_vals[i];</div>
|
||||
<div class="line"><a id="l00650" name="l00650"></a><span class="lineno"> 650</span> tgp_idxs[idx] = thread_idxs[i];</div>
|
||||
<div class="line"><a id="l00651" name="l00651"></a><span class="lineno"> 651</span> }</div>
|
||||
<div class="line"><a id="l00652" name="l00652"></a><span class="lineno"> 652</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00653" name="l00653"></a><span class="lineno"> 653</span> </div>
|
||||
<div class="line"><a id="l00654" name="l00654"></a><span class="lineno"> 654</span> <span class="comment">// Merge</span></div>
|
||||
<div class="line"><a id="l00655" name="l00655"></a><span class="lineno"> 655</span> <span class="keywordtype">int</span> sort_md_local = <a class="code hl_function" href="namespacemetal.html#a6653b28c9473087141eddce39878d4d3">min</a>(A_sz + B_sz, N_PER_THREAD * <span class="keywordtype">int</span>(lid.x));</div>
|
||||
<div class="line"><a id="l00656" name="l00656"></a><span class="lineno"> 656</span> </div>
|
||||
<div class="line"><a id="l00657" name="l00657"></a><span class="lineno"> 657</span> <span class="keywordtype">int</span> A_st_local = block_sort_t::merge_partition(</div>
|
||||
<div class="line"><a id="l00658" name="l00658"></a><span class="lineno"> 658</span> tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);</div>
|
||||
<div class="line"><a id="l00659" name="l00659"></a><span class="lineno"> 659</span> <span class="keywordtype">int</span> A_ed_local = A_sz;</div>
|
||||
<div class="line"><a id="l00660" name="l00660"></a><span class="lineno"> 660</span> </div>
|
||||
<div class="line"><a id="l00661" name="l00661"></a><span class="lineno"> 661</span> <span class="comment">// Do merge</span></div>
|
||||
<div class="line"><a id="l00662" name="l00662"></a><span class="lineno"> 662</span> block_sort_t::merge_step(</div>
|
||||
<div class="line"><a id="l00663" name="l00663"></a><span class="lineno"> 663</span> tgp_vals + A_st_local,</div>
|
||||
<div class="line"><a id="l00664" name="l00664"></a><span class="lineno"> 664</span> tgp_vals + A_ed_local + B_st_local,</div>
|
||||
<div class="line"><a id="l00665" name="l00665"></a><span class="lineno"> 665</span> tgp_idxs + A_st_local,</div>
|
||||
<div class="line"><a id="l00666" name="l00666"></a><span class="lineno"> 666</span> tgp_idxs + A_ed_local + B_st_local,</div>
|
||||
<div class="line"><a id="l00667" name="l00667"></a><span class="lineno"> 667</span> A_sz_local,</div>
|
||||
<div class="line"><a id="l00668" name="l00668"></a><span class="lineno"> 668</span> B_sz_local,</div>
|
||||
<div class="line"><a id="l00669" name="l00669"></a><span class="lineno"> 669</span> thread_vals,</div>
|
||||
<div class="line"><a id="l00670" name="l00670"></a><span class="lineno"> 670</span> thread_idxs);</div>
|
||||
<div class="line"><a id="l00671" name="l00671"></a><span class="lineno"> 671</span> </div>
|
||||
<div class="line"><a id="l00672" name="l00672"></a><span class="lineno"> 672</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00673" name="l00673"></a><span class="lineno"> 673</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; ++i) {</div>
|
||||
<div class="line"><a id="l00674" name="l00674"></a><span class="lineno"> 674</span> <span class="keywordtype">int</span> idx = lid.x * N_PER_THREAD;</div>
|
||||
<div class="line"><a id="l00675" name="l00675"></a><span class="lineno"> 675</span> tgp_vals[idx + i] = thread_vals[i];</div>
|
||||
<div class="line"><a id="l00676" name="l00676"></a><span class="lineno"> 676</span> tgp_idxs[idx + i] = thread_idxs[i];</div>
|
||||
<div class="line"><a id="l00677" name="l00677"></a><span class="lineno"> 677</span> }</div>
|
||||
<div class="line"><a id="l00678" name="l00678"></a><span class="lineno"> 678</span> </div>
|
||||
<div class="line"><a id="l00679" name="l00679"></a><span class="lineno"> 679</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00680" name="l00680"></a><span class="lineno"> 680</span> <span class="comment">// Write output</span></div>
|
||||
<div class="line"><a id="l00681" name="l00681"></a><span class="lineno"> 681</span> <span class="keywordtype">int</span> base_idx = tid.x * sort_kernel::N_PER_BLOCK;</div>
|
||||
<div class="line"><a id="l00682" name="l00682"></a><span class="lineno"> 682</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {</div>
|
||||
<div class="line"><a id="l00683" name="l00683"></a><span class="lineno"> 683</span> <span class="keywordtype">int</span> idx = base_idx + i;</div>
|
||||
<div class="line"><a id="l00684" name="l00684"></a><span class="lineno"> 684</span> <span class="keywordflow">if</span> (idx < size_sorted_axis) {</div>
|
||||
<div class="line"><a id="l00685" name="l00685"></a><span class="lineno"> 685</span> dev_vals_out[idx] = tgp_vals[i];</div>
|
||||
<div class="line"><a id="l00686" name="l00686"></a><span class="lineno"> 686</span> dev_idxs_out[idx] = tgp_idxs[i];</div>
|
||||
<div class="line"><a id="l00687" name="l00687"></a><span class="lineno"> 687</span> }</div>
|
||||
<div class="line"><a id="l00688" name="l00688"></a><span class="lineno"> 688</span> }</div>
|
||||
<div class="line"><a id="l00689" name="l00689"></a><span class="lineno"> 689</span>}</div>
|
||||
<div class="line"><a id="l00661" name="l00661"></a><span class="lineno"> 661</span> <span class="keywordtype">int</span> B_st_local = sort_md_local - A_st_local;</div>
|
||||
<div class="line"><a id="l00662" name="l00662"></a><span class="lineno"> 662</span> <span class="keywordtype">int</span> B_ed_local = B_sz;</div>
|
||||
<div class="line"><a id="l00663" name="l00663"></a><span class="lineno"> 663</span> </div>
|
||||
<div class="line"><a id="l00664" name="l00664"></a><span class="lineno"> 664</span> <span class="keywordtype">int</span> A_sz_local = A_ed_local - A_st_local;</div>
|
||||
<div class="line"><a id="l00665" name="l00665"></a><span class="lineno"> 665</span> <span class="keywordtype">int</span> B_sz_local = B_ed_local - B_st_local;</div>
|
||||
<div class="line"><a id="l00666" name="l00666"></a><span class="lineno"> 666</span> </div>
|
||||
<div class="line"><a id="l00667" name="l00667"></a><span class="lineno"> 667</span> <span class="comment">// Do merge</span></div>
|
||||
<div class="line"><a id="l00668" name="l00668"></a><span class="lineno"> 668</span> block_sort_t::merge_step(</div>
|
||||
<div class="line"><a id="l00669" name="l00669"></a><span class="lineno"> 669</span> tgp_vals + A_st_local,</div>
|
||||
<div class="line"><a id="l00670" name="l00670"></a><span class="lineno"> 670</span> tgp_vals + A_ed_local + B_st_local,</div>
|
||||
<div class="line"><a id="l00671" name="l00671"></a><span class="lineno"> 671</span> tgp_idxs + A_st_local,</div>
|
||||
<div class="line"><a id="l00672" name="l00672"></a><span class="lineno"> 672</span> tgp_idxs + A_ed_local + B_st_local,</div>
|
||||
<div class="line"><a id="l00673" name="l00673"></a><span class="lineno"> 673</span> A_sz_local,</div>
|
||||
<div class="line"><a id="l00674" name="l00674"></a><span class="lineno"> 674</span> B_sz_local,</div>
|
||||
<div class="line"><a id="l00675" name="l00675"></a><span class="lineno"> 675</span> thread_vals,</div>
|
||||
<div class="line"><a id="l00676" name="l00676"></a><span class="lineno"> 676</span> thread_idxs);</div>
|
||||
<div class="line"><a id="l00677" name="l00677"></a><span class="lineno"> 677</span> </div>
|
||||
<div class="line"><a id="l00678" name="l00678"></a><span class="lineno"> 678</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00679" name="l00679"></a><span class="lineno"> 679</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i < N_PER_THREAD; ++i) {</div>
|
||||
<div class="line"><a id="l00680" name="l00680"></a><span class="lineno"> 680</span> <span class="keywordtype">int</span> idx = lid.x * N_PER_THREAD;</div>
|
||||
<div class="line"><a id="l00681" name="l00681"></a><span class="lineno"> 681</span> tgp_vals[idx + i] = thread_vals[i];</div>
|
||||
<div class="line"><a id="l00682" name="l00682"></a><span class="lineno"> 682</span> tgp_idxs[idx + i] = thread_idxs[i];</div>
|
||||
<div class="line"><a id="l00683" name="l00683"></a><span class="lineno"> 683</span> }</div>
|
||||
<div class="line"><a id="l00684" name="l00684"></a><span class="lineno"> 684</span> </div>
|
||||
<div class="line"><a id="l00685" name="l00685"></a><span class="lineno"> 685</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
|
||||
<div class="line"><a id="l00686" name="l00686"></a><span class="lineno"> 686</span> <span class="comment">// Write output</span></div>
|
||||
<div class="line"><a id="l00687" name="l00687"></a><span class="lineno"> 687</span> <span class="keywordtype">int</span> base_idx = tid.x * sort_kernel::N_PER_BLOCK;</div>
|
||||
<div class="line"><a id="l00688" name="l00688"></a><span class="lineno"> 688</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {</div>
|
||||
<div class="line"><a id="l00689" name="l00689"></a><span class="lineno"> 689</span> <span class="keywordtype">int</span> idx = base_idx + i;</div>
|
||||
<div class="line"><a id="l00690" name="l00690"></a><span class="lineno"> 690</span> <span class="keywordflow">if</span> (idx < size_sorted_axis) {</div>
|
||||
<div class="line"><a id="l00691" name="l00691"></a><span class="lineno"> 691</span> dev_vals_out[idx] = tgp_vals[i];</div>
|
||||
<div class="line"><a id="l00692" name="l00692"></a><span class="lineno"> 692</span> dev_idxs_out[idx] = tgp_idxs[i];</div>
|
||||
<div class="line"><a id="l00693" name="l00693"></a><span class="lineno"> 693</span> }</div>
|
||||
<div class="line"><a id="l00694" name="l00694"></a><span class="lineno"> 694</span> }</div>
|
||||
<div class="line"><a id="l00695" name="l00695"></a><span class="lineno"> 695</span>}</div>
|
||||
</div>
|
||||
<div class="ttc" id="abackend_2metal_2kernels_2utils_8h_html_a2e49fa7ab8f6348543455c6c45d7e2a9"><div class="ttname"><a href="backend_2metal_2kernels_2utils_8h.html#a2e49fa7ab8f6348543455c6c45d7e2a9">elem_to_loc</a></div><div class="ttdeci">METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:77</div></div>
|
||||
<div class="ttc" id="acommon_2binary_8h_html_a70228731d29946574b238d21fb4b360c"><div class="ttname"><a href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a></div><div class="ttdeci">Op op</div><div class="ttdef"><b>Definition</b> binary.h:141</div></div>
|
||||
@@ -812,10 +818,10 @@ $(function() { codefold.init(0); });
|
||||
<div class="ttc" id="asort_8h_html_a0386011c52d03e60885a31e6fbd903dd"><div class="ttname"><a href="sort_8h.html#a0386011c52d03e60885a31e6fbd903dd">MLX_MTL_CONST</a></div><div class="ttdeci">#define MLX_MTL_CONST</div><div class="ttdef"><b>Definition</b> sort.h:3</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a29229399f51e5c440ffe5c9b99b27598"><div class="ttname"><a href="sort_8h.html#a29229399f51e5c440ffe5c9b99b27598">block_sort_nc</a></div><div class="ttdeci">void block_sort_nc(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *in_nc_strides, const device size_t *out_nc_strides, uint3 tid, uint3 lid)</div><div class="ttdef"><b>Definition</b> sort.h:338</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a2a0533103661dd378d6bfe949930650a"><div class="ttname"><a href="sort_8h.html#a2a0533103661dd378d6bfe949930650a">mb_block_sort</a></div><div class="ttdeci">void mb_block_sort(const device val_t *inp, device val_t *out_vals, device idx_t *out_idxs, const constant int &size_sorted_axis, const constant int &stride_sorted_axis, const constant int &nc_dim, const device int *nc_shape, const device size_t *nc_strides, uint3 tid, uint3 lid)</div><div class="ttdef"><b>Definition</b> sort.h:481</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a50ae11454e4dfa374a9bd256cdbbf605"><div class="ttname"><a href="sort_8h.html#a50ae11454e4dfa374a9bd256cdbbf605">mb_block_partition</a></div><div class="ttdeci">void mb_block_partition(device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, uint3 tid, uint3 lid, uint3 tgp_dims)</div><div class="ttdef"><b>Definition</b> sort.h:526</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a32cbe4163b8b0f5cb2c97b256119a4b2"><div class="ttname"><a href="sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2">mb_block_partition</a></div><div class="ttdeci">void mb_block_partition(device idx_t *block_partitions, const device val_t *dev_vals, const device idx_t *dev_idxs, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &n_blocks, uint3 tid, uint3 lid, uint3 tgp_dims)</div><div class="ttdef"><b>Definition</b> sort.h:525</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a6e8c2da4975a8001fd5ddf211a3058b7"><div class="ttname"><a href="sort_8h.html#a6e8c2da4975a8001fd5ddf211a3058b7">thread_swap</a></div><div class="ttdeci">METAL_FUNC void thread_swap(thread T &a, thread T &b)</div><div class="ttdef"><b>Definition</b> sort.h:16</div></div>
|
||||
<div class="ttc" id="asort_8h_html_a93f14092416169c4449141043ac45ffd"><div class="ttname"><a href="sort_8h.html#a93f14092416169c4449141043ac45ffd">block_sort</a></div><div class="ttdeci">void block_sort(const device T *inp, device U *out, const constant int &size_sorted_axis, const constant int &in_stride_sorted_axis, const constant int &out_stride_sorted_axis, const constant int &in_stride_segment_axis, const constant int &out_stride_segment_axis, uint3 tid, uint3 lid)</div><div class="ttdef"><b>Definition</b> sort.h:283</div></div>
|
||||
<div class="ttc" id="asort_8h_html_ab381cd57f344bc7304ab580bfdc78807"><div class="ttname"><a href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807">mb_block_merge</a></div><div class="ttdeci">void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)</div><div class="ttdef"><b>Definition</b> sort.h:573</div></div>
|
||||
<div class="ttc" id="asort_8h_html_ab381cd57f344bc7304ab580bfdc78807"><div class="ttname"><a href="sort_8h.html#ab381cd57f344bc7304ab580bfdc78807">mb_block_merge</a></div><div class="ttdeci">void mb_block_merge(const device idx_t *block_partitions, const device val_t *dev_vals_in, const device idx_t *dev_idxs_in, device val_t *dev_vals_out, device idx_t *dev_idxs_out, const constant int &size_sorted_axis, const constant int &merge_tiles, const constant int &num_tiles, uint3 tid, uint3 lid)</div><div class="ttdef"><b>Definition</b> sort.h:579</div></div>
|
||||
<div class="ttc" id="asort_8h_html_aca8b6f36c9024b8406fe545765316dc0"><div class="ttname"><a href="sort_8h.html#aca8b6f36c9024b8406fe545765316dc0">zero_helper</a></div><div class="ttdeci">constant constexpr const int zero_helper</div><div class="ttdef"><b>Definition</b> sort.h:330</div></div>
|
||||
<div class="ttc" id="asort_8h_html_ad34b622323cebef136669fedd7229515"><div class="ttname"><a href="sort_8h.html#ad34b622323cebef136669fedd7229515">MLX_MTL_LOOP_UNROLL</a></div><div class="ttdeci">#define MLX_MTL_LOOP_UNROLL</div><div class="ttdef"><b>Definition</b> sort.h:4</div></div>
|
||||
<div class="ttc" id="astruct_block_merge_sort_html"><div class="ttname"><a href="struct_block_merge_sort.html">BlockMergeSort</a></div><div class="ttdef"><b>Definition</b> sort.h:67</div></div>
|
||||
|
Reference in New Issue
Block a user