diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 3aa54de3e..50b1cfbb6 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -592,7 +592,7 @@ template < bool ARG_SORT, short BLOCK_THREADS, short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton( +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition( device idx_t* block_partitions [[buffer(0)]], const device val_t* dev_vals [[buffer(1)]], const device idx_t* dev_idxs [[buffer(2)]], @@ -777,8 +777,8 @@ template < const device size_t* nc_strides [[buffer(7)]], \ uint3 tid [[threadgroup_position_in_grid]], \ uint3 lid [[thread_position_in_threadgroup]]); \ - template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ - [[kernel]] void mb_block_partiton( \ + template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ + [[kernel]] void mb_block_partition( \ device itype* block_partitions [[buffer(0)]], \ const device vtype* dev_vals [[buffer(1)]], \ const device itype* dev_idxs [[buffer(2)]], \ diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index befbf2d81..9eb9960e0 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -165,10 +165,10 @@ void multi_block_sort( dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1; ping = !ping; - // Do partiton + // Do partition { std::ostringstream kname; - kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_" + kname << "mb_block_partition_" << type_to_name(dev_vals_in) << "_" << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn; auto kernel = d.get_kernel(kname.str()); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index db27a14a1..cab5f28c6 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2253,7 +2253,7 @@ void init_ops(py::module_& m) { will be of elements less or equal to the element at the ``kth`` index and all indices after will be of elements greater or equal to the element at the ``kth`` index. - axis (int or None, optional): Optional axis to partiton over. + axis (int or None, optional): Optional axis to partition over. If ``None``, this partitions over the flattened array. If unspecified, it defaults to ``-1``.