// Copyright © 2023 Apple Inc. #include #include #include #include #include "mlx/backend/common/copy.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" namespace mlx::core { namespace { template struct StridedIterator { using iterator_category = std::random_access_iterator_tag; using difference_type = int32_t; using value_type = T; using reference = value_type&; using pointer = value_type*; // Constructors StridedIterator() = default; explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0) : ptr_(ptr + offset * stride), stride_(stride) {} explicit StridedIterator(array& arr, int axis, difference_type offset = 0) : StridedIterator(arr.data(), arr.strides()[axis], offset) {} // Accessors reference operator*() const { return ptr_[0]; } reference operator[](difference_type idx) const { return ptr_[idx * stride_]; } // Comparisons bool operator==(const StridedIterator& other) const { return ptr_ == other.ptr_ && stride_ == other.stride_; } bool operator!=(const StridedIterator& other) const { return ptr_ != other.ptr_; } bool operator<(const StridedIterator& other) const { return ptr_ < other.ptr_; } bool operator>(const StridedIterator& other) const { return ptr_ > other.ptr_; } bool operator<=(const StridedIterator& other) const { return ptr_ <= other.ptr_; } bool operator>=(const StridedIterator& other) const { return ptr_ >= other.ptr_; } difference_type operator-(const StridedIterator& other) const { return (ptr_ - other.ptr_) / stride_; } // Moving StridedIterator& operator++() { ptr_ += stride_; return *this; } StridedIterator& operator--() { ptr_ -= stride_; return *this; } StridedIterator& operator+=(difference_type diff) { ptr_ += diff * stride_; return *this; } StridedIterator& operator-=(difference_type diff) { ptr_ -= diff * stride_; return *this; } StridedIterator operator+(difference_type diff) { return StridedIterator(ptr_, stride_, diff); } StridedIterator operator-(difference_type diff) { return StridedIterator(ptr_, stride_, -diff); } private: int64_t stride_; T* ptr_; }; template void sort(const array& in, array& out, int axis) { // Copy input to output CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; copy(in, out, ctype); // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t n_rows = in_size / in.shape(axis); auto remaining_shape = out.shape(); remaining_shape.erase(remaining_shape.begin() + axis); auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); auto axis_stride = out.strides()[axis]; auto axis_size = out.shape(axis); // Perform sorting in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { T* data_ptr = out.data() + src_it.loc; StridedIterator st(data_ptr, axis_stride, 0); StridedIterator ed(data_ptr, axis_stride, axis_size); std::stable_sort(st, ed); src_it.step(); } } template void argsort(const array& in, array& out, int axis) { // Allocate output out.set_data(allocator::malloc_or_wait(out.nbytes())); // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); auto in_remaining_strides = in.strides(); in_remaining_strides.erase(in_remaining_strides.begin() + axis); auto out_remaining_shape = out.shape(); out_remaining_shape.erase(out_remaining_shape.begin() + axis); auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); auto in_stride = in.strides()[axis]; auto out_stride = out.strides()[axis]; auto axis_size = in.shape(axis); // Perform sorting ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in.data() + in_it.loc; IdxT* idx_ptr = out.data() + out_it.loc; in_it.step(); out_it.step(); StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Sort according to vals StridedIterator st(idx_ptr, out_stride, 0); StridedIterator ed(idx_ptr, out_stride, axis_size); std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) { auto v1 = data_ptr[a * in_stride]; auto v2 = data_ptr[b * in_stride]; return v1 < v2 || (v1 == v2 && a < b); }); } } template void partition(const array& in, array& out, int axis, int kth) { // Copy input to output CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; copy(in, out, ctype); // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t in_size = in.flags().contiguous ? in.data_size() : in.size(); size_t n_rows = in_size / in.shape(axis); auto remaining_shape = in.shape(); remaining_shape.erase(remaining_shape.begin() + axis); auto remaining_strides = in.strides(); remaining_strides.erase(remaining_strides.begin() + axis); auto axis_stride = in.strides()[axis]; int axis_size = in.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition in place ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { T* data_ptr = out.data() + src_it.loc; src_it.step(); StridedIterator st(data_ptr, axis_stride, 0); StridedIterator md(data_ptr, axis_stride, kth); StridedIterator ed(data_ptr, axis_stride, axis_size); std::nth_element(st, md, ed); } } template void argpartition(const array& in, array& out, int axis, int kth) { // Allocate output out.set_data(allocator::malloc_or_wait(out.nbytes())); // Get axis, shape and stride info axis = axis < 0 ? axis + in.ndim() : axis; size_t n_rows = in.size() / in.shape(axis); auto in_remaining_shape = in.shape(); in_remaining_shape.erase(in_remaining_shape.begin() + axis); auto in_remaining_strides = in.strides(); in_remaining_strides.erase(in_remaining_strides.begin() + axis); auto out_remaining_shape = out.shape(); out_remaining_shape.erase(out_remaining_shape.begin() + axis); auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); auto in_stride = in.strides()[axis]; auto out_stride = out.strides()[axis]; auto axis_size = in.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in.data() + in_it.loc; IdxT* idx_ptr = out.data() + out_it.loc; in_it.step(); out_it.step(); StridedIterator st_(idx_ptr, out_stride, 0); StridedIterator ed_(idx_ptr, out_stride, axis_size); // Initialize with iota std::iota(st_, ed_, IdxT(0)); // Sort according to vals StridedIterator st(idx_ptr, out_stride, 0); StridedIterator md(idx_ptr, out_stride, kth); StridedIterator ed(idx_ptr, out_stride, axis_size); std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) { auto v1 = data_ptr[a * in_stride]; auto v2 = data_ptr[b * in_stride]; return v1 < v2 || (v1 == v2 && a < b); }); } } } // namespace void ArgSort::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; switch (in.dtype()) { case bool_: return argsort(in, out, axis_); case uint8: return argsort(in, out, axis_); case uint16: return argsort(in, out, axis_); case uint32: return argsort(in, out, axis_); case uint64: return argsort(in, out, axis_); case int8: return argsort(in, out, axis_); case int16: return argsort(in, out, axis_); case int32: return argsort(in, out, axis_); case int64: return argsort(in, out, axis_); case float32: return argsort(in, out, axis_); case float16: return argsort(in, out, axis_); case bfloat16: return argsort(in, out, axis_); case complex64: return argsort(in, out, axis_); } } void Sort::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; switch (in.dtype()) { case bool_: return sort(in, out, axis_); case uint8: return sort(in, out, axis_); case uint16: return sort(in, out, axis_); case uint32: return sort(in, out, axis_); case uint64: return sort(in, out, axis_); case int8: return sort(in, out, axis_); case int16: return sort(in, out, axis_); case int32: return sort(in, out, axis_); case int64: return sort(in, out, axis_); case float32: return sort(in, out, axis_); case float16: return sort(in, out, axis_); case bfloat16: return sort(in, out, axis_); case complex64: return sort(in, out, axis_); } } void ArgPartition::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; switch (in.dtype()) { case bool_: return argpartition(in, out, axis_, kth_); case uint8: return argpartition(in, out, axis_, kth_); case uint16: return argpartition(in, out, axis_, kth_); case uint32: return argpartition(in, out, axis_, kth_); case uint64: return argpartition(in, out, axis_, kth_); case int8: return argpartition(in, out, axis_, kth_); case int16: return argpartition(in, out, axis_, kth_); case int32: return argpartition(in, out, axis_, kth_); case int64: return argpartition(in, out, axis_, kth_); case float32: return argpartition(in, out, axis_, kth_); case float16: return argpartition(in, out, axis_, kth_); case bfloat16: return argpartition(in, out, axis_, kth_); case complex64: return argpartition(in, out, axis_, kth_); } } void Partition::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; switch (in.dtype()) { case bool_: return partition(in, out, axis_, kth_); case uint8: return partition(in, out, axis_, kth_); case uint16: return partition(in, out, axis_, kth_); case uint32: return partition(in, out, axis_, kth_); case uint64: return partition(in, out, axis_, kth_); case int8: return partition(in, out, axis_, kth_); case int16: return partition(in, out, axis_, kth_); case int32: return partition(in, out, axis_, kth_); case int64: return partition(in, out, axis_, kth_); case float32: return partition(in, out, axis_, kth_); case float16: return partition(in, out, axis_, kth_); case bfloat16: return partition(in, out, axis_, kth_); case complex64: return partition(in, out, axis_, kth_); } } } // namespace mlx::core