mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 11:38:06 +08:00
spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
c18cb065b4
commit
979ae32b98
@ -592,7 +592,7 @@ template <
|
|||||||
bool ARG_SORT,
|
bool ARG_SORT,
|
||||||
short BLOCK_THREADS,
|
short BLOCK_THREADS,
|
||||||
short N_PER_THREAD>
|
short N_PER_THREAD>
|
||||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||||
device idx_t* block_partitions [[buffer(0)]],
|
device idx_t* block_partitions [[buffer(0)]],
|
||||||
const device val_t* dev_vals [[buffer(1)]],
|
const device val_t* dev_vals [[buffer(1)]],
|
||||||
const device idx_t* dev_idxs [[buffer(2)]],
|
const device idx_t* dev_idxs [[buffer(2)]],
|
||||||
@ -777,8 +777,8 @@ template <
|
|||||||
const device size_t* nc_strides [[buffer(7)]], \
|
const device size_t* nc_strides [[buffer(7)]], \
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
uint3 tid [[threadgroup_position_in_grid]], \
|
||||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||||
device itype* block_partitions [[buffer(0)]], \
|
device itype* block_partitions [[buffer(0)]], \
|
||||||
const device vtype* dev_vals [[buffer(1)]], \
|
const device vtype* dev_vals [[buffer(1)]], \
|
||||||
const device itype* dev_idxs [[buffer(2)]], \
|
const device itype* dev_idxs [[buffer(2)]], \
|
||||||
|
@ -165,10 +165,10 @@ void multi_block_sort(
|
|||||||
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
||||||
ping = !ping;
|
ping = !ping;
|
||||||
|
|
||||||
// Do partiton
|
// Do partition
|
||||||
{
|
{
|
||||||
std::ostringstream kname;
|
std::ostringstream kname;
|
||||||
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
|
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
|
||||||
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||||
|
|
||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
|
@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) {
|
|||||||
will be of elements less or equal to the element at the ``kth``
|
will be of elements less or equal to the element at the ``kth``
|
||||||
index and all indices after will be of elements greater or equal
|
index and all indices after will be of elements greater or equal
|
||||||
to the element at the ``kth`` index.
|
to the element at the ``kth`` index.
|
||||||
axis (int or None, optional): Optional axis to partiton over.
|
axis (int or None, optional): Optional axis to partition over.
|
||||||
If ``None``, this partitions over the flattened array.
|
If ``None``, this partitions over the flattened array.
|
||||||
If unspecified, it defaults to ``-1``.
|
If unspecified, it defaults to ``-1``.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user