mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
reduce binary size (#1952)
This commit is contained in:
@@ -105,15 +105,11 @@ struct StridedIterator {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
void sort(array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -127,30 +123,20 @@ void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
});
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -176,99 +162,69 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = in.strides();
|
||||
auto remaining_strides = out.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
auto axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
kth]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
});
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
int kth,
|
||||
Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -297,42 +253,32 @@ void argpartition(
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride,
|
||||
kth]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_, stream());
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_, stream());
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_, stream());
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_, stream());
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_, stream());
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_, stream());
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_, stream());
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_, stream());
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_, stream());
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_, stream());
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_, stream());
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_, stream());
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_, stream());
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_, stream());
|
||||
}
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(in, out, axis_, stream());
|
||||
case uint8:
|
||||
return sort<uint8_t>(in, out, axis_, stream());
|
||||
case uint16:
|
||||
return sort<uint16_t>(in, out, axis_, stream());
|
||||
case uint32:
|
||||
return sort<uint32_t>(in, out, axis_, stream());
|
||||
case uint64:
|
||||
return sort<uint64_t>(in, out, axis_, stream());
|
||||
case int8:
|
||||
return sort<int8_t>(in, out, axis_, stream());
|
||||
case int16:
|
||||
return sort<int16_t>(in, out, axis_, stream());
|
||||
case int32:
|
||||
return sort<int32_t>(in, out, axis_, stream());
|
||||
case int64:
|
||||
return sort<int64_t>(in, out, axis_, stream());
|
||||
case float32:
|
||||
return sort<float>(in, out, axis_, stream());
|
||||
case float64:
|
||||
return sort<double>(in, out, axis_, stream());
|
||||
case float16:
|
||||
return sort<float16_t>(in, out, axis_, stream());
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(in, out, axis_, stream());
|
||||
case complex64:
|
||||
return sort<complex64_t>(in, out, axis_, stream());
|
||||
}
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_, stream());
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_, stream());
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_, stream());
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_, stream());
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_, stream());
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_, stream());
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_, stream());
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_, stream());
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
}
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(in, out, axis_, kth_, stream());
|
||||
case uint8:
|
||||
return partition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
case uint16:
|
||||
return partition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
case uint32:
|
||||
return partition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
case uint64:
|
||||
return partition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
case int8:
|
||||
return partition<int8_t>(in, out, axis_, kth_, stream());
|
||||
case int16:
|
||||
return partition<int16_t>(in, out, axis_, kth_, stream());
|
||||
case int32:
|
||||
return partition<int32_t>(in, out, axis_, kth_, stream());
|
||||
case int64:
|
||||
return partition<int64_t>(in, out, axis_, kth_, stream());
|
||||
case float32:
|
||||
return partition<float>(in, out, axis_, kth_, stream());
|
||||
case float64:
|
||||
return partition<double>(in, out, axis_, kth_, stream());
|
||||
case float16:
|
||||
return partition<float16_t>(in, out, axis_, kth_, stream());
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
case complex64:
|
||||
return partition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
}
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(out, axis_, kth_);
|
||||
case uint8:
|
||||
return partition<uint8_t>(out, axis_, kth_);
|
||||
case uint16:
|
||||
return partition<uint16_t>(out, axis_, kth_);
|
||||
case uint32:
|
||||
return partition<uint32_t>(out, axis_, kth_);
|
||||
case uint64:
|
||||
return partition<uint64_t>(out, axis_, kth_);
|
||||
case int8:
|
||||
return partition<int8_t>(out, axis_, kth_);
|
||||
case int16:
|
||||
return partition<int16_t>(out, axis_, kth_);
|
||||
case int32:
|
||||
return partition<int32_t>(out, axis_, kth_);
|
||||
case int64:
|
||||
return partition<int64_t>(out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(out, axis_, kth_);
|
||||
case complex64:
|
||||
return partition<complex64_t>(out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user