mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 18:39:45 +08:00
spelling: partition
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
c18cb065b4
commit
979ae32b98
@ -592,7 +592,7 @@ template <
|
||||
bool ARG_SORT,
|
||||
short BLOCK_THREADS,
|
||||
short N_PER_THREAD>
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton(
|
||||
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
|
||||
device idx_t* block_partitions [[buffer(0)]],
|
||||
const device val_t* dev_vals [[buffer(1)]],
|
||||
const device idx_t* dev_idxs [[buffer(2)]],
|
||||
@ -777,8 +777,8 @@ template <
|
||||
const device size_t* nc_strides [[buffer(7)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]]); \
|
||||
template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partiton<vtype, itype, arg_sort, bn, tn>( \
|
||||
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
|
||||
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
|
||||
device itype* block_partitions [[buffer(0)]], \
|
||||
const device vtype* dev_vals [[buffer(1)]], \
|
||||
const device itype* dev_idxs [[buffer(2)]], \
|
||||
|
@ -165,10 +165,10 @@ void multi_block_sort(
|
||||
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
|
||||
ping = !ping;
|
||||
|
||||
// Do partiton
|
||||
// Do partition
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
|
||||
kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_"
|
||||
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
|
||||
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) {
|
||||
will be of elements less or equal to the element at the ``kth``
|
||||
index and all indices after will be of elements greater or equal
|
||||
to the element at the ``kth`` index.
|
||||
axis (int or None, optional): Optional axis to partiton over.
|
||||
axis (int or None, optional): Optional axis to partition over.
|
||||
If ``None``, this partitions over the flattened array.
|
||||
If unspecified, it defaults to ``-1``.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user