reduce binary size (#1952)

Awni Hannun
2025-03-11 06:30:44 -07:00
committed by GitHub
parent 117e1355a2
commit 736a340478
16 changed files with 2145 additions and 2386 deletions


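The hunks below all make the same move: the templated kernels (sort, argsort, partition, argpartition) stop creating a command encoder and dispatching their own lambda, and each eval_cpu method instead allocates or copies the output, sets up the encoder once, and runs the whole dtype switch inside a single dispatched lambda over array::unsafe_weak_copy handles. With 14 dtype cases per op and 4 ops in this file, that shrinks the number of distinct dispatch closures instantiated into the binary from roughly 4 ops x 14 dtypes = 56 to 4, which is where the size saving comes from. What follows is a minimal standalone sketch of the pattern, not MLX's actual API; Task, queue, Dtype, eval_sort, and sort_kernel are stand-in names.

// Minimal standalone sketch of the pattern (stand-in names, not MLX's API):
// the per-dtype kernel is a plain template with no dispatch code inside,
// and a single queued closure per op wraps the dtype switch.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

using Task = std::function<void()>;
std::vector<Task> queue; // stand-in for the stream's command encoder

enum class Dtype { int32, float32 };

// Per-dtype kernel: only the sort itself gets instantiated per type.
template <typename T>
void sort_kernel(T* data, std::size_t n) {
  std::stable_sort(data, data + n);
}

// One dispatched closure per op; the dtype switch lives inside it.
void eval_sort(void* data, std::size_t n, Dtype dtype) {
  queue.push_back([data, n, dtype]() {
    switch (dtype) {
      case Dtype::int32:
        return sort_kernel(static_cast<std::int32_t*>(data), n);
      case Dtype::float32:
        return sort_kernel(static_cast<float*>(data), n);
    }
  });
}

int main() {
  float vals[] = {3.0f, 1.0f, 2.0f};
  eval_sort(vals, 3, Dtype::float32);
  for (auto& task : queue) {
    task(); // run the queued work, as the encoder would on its stream
  }
}

In the real diff the queue is cpu::get_command_encoder(stream()) and the queued closure captures weak copies of the arrays before switching on the dtype.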
@@ -105,15 +105,11 @@ struct StridedIterator {
};
template <typename T>
-void sort(const array& in, array& out, int axis, Stream stream) {
-// Copy input to output
-CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-copy(in, out, ctype, stream);
+void sort(array& out, int axis) {
// Get axis, shape and stride info
-axis = axis < 0 ? axis + in.ndim() : axis;
-size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-size_t n_rows = in_size / in.shape(axis);
+axis = axis < 0 ? axis + out.ndim() : axis;
+size_t in_size = out.size();
+size_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -127,30 +123,20 @@ void sort(const array& in, array& out, int axis, Stream stream) {
// Perform sorting in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
-auto& encoder = cpu::get_command_encoder(stream);
-encoder.set_output_array(out);
-encoder.dispatch([out_ptr = out.data<T>(),
-src_it = std::move(src_it),
-n_rows,
-axis_size,
-axis_stride]() mutable {
-for (int i = 0; i < n_rows; i++) {
-T* data_ptr = out_ptr + src_it.loc;
+auto out_ptr = out.data<T>();
+for (int i = 0; i < n_rows; i++) {
+T* data_ptr = out_ptr + src_it.loc;
-StridedIterator st(data_ptr, axis_stride, 0);
-StridedIterator ed(data_ptr, axis_stride, axis_size);
+StridedIterator st(data_ptr, axis_stride, 0);
+StridedIterator ed(data_ptr, axis_stride, axis_size);
-std::stable_sort(st, ed);
-src_it.step();
-}
-});
+std::stable_sort(st, ed);
+src_it.step();
+}
}
template <typename T, typename IdxT = uint32_t>
-void argsort(const array& in, array& out, int axis, Stream stream) {
-// Allocate output
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+void argsort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -176,99 +162,69 @@ void argsort(const array& in, array& out, int axis, Stream stream) {
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-auto& encoder = cpu::get_command_encoder(stream);
-encoder.set_input_array(in);
-encoder.set_input_array(out);
-encoder.dispatch([in_ptr = in.data<T>(),
-out_ptr = out.data<IdxT>(),
-in_it = std::move(in_it),
-out_it = std::move(out_it),
-n_rows,
-axis_size,
-in_stride,
-out_stride]() mutable {
-for (int i = 0; i < n_rows; i++) {
-const T* data_ptr = in_ptr + in_it.loc;
-IdxT* idx_ptr = out_ptr + out_it.loc;
+auto in_ptr = in.data<T>();
+auto out_ptr = out.data<IdxT>();
+for (int i = 0; i < n_rows; i++) {
+const T* data_ptr = in_ptr + in_it.loc;
+IdxT* idx_ptr = out_ptr + out_it.loc;
-in_it.step();
-out_it.step();
+in_it.step();
+out_it.step();
-StridedIterator st_(idx_ptr, out_stride, 0);
-StridedIterator ed_(idx_ptr, out_stride, axis_size);
+StridedIterator st_(idx_ptr, out_stride, 0);
+StridedIterator ed_(idx_ptr, out_stride, axis_size);
-// Initialize with iota
-std::iota(st_, ed_, IdxT(0));
+// Initialize with iota
+std::iota(st_, ed_, IdxT(0));
-// Sort according to vals
-StridedIterator st(idx_ptr, out_stride, 0);
-StridedIterator ed(idx_ptr, out_stride, axis_size);
+// Sort according to vals
+StridedIterator st(idx_ptr, out_stride, 0);
+StridedIterator ed(idx_ptr, out_stride, axis_size);
-std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-auto v1 = data_ptr[a * in_stride];
-auto v2 = data_ptr[b * in_stride];
-return v1 < v2 || (v1 == v2 && a < b);
-});
-}
-});
+std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+auto v1 = data_ptr[a * in_stride];
+auto v2 = data_ptr[b * in_stride];
+return v1 < v2 || (v1 == v2 && a < b);
+});
+}
}
template <typename T>
-void partition(const array& in, array& out, int axis, int kth, Stream stream) {
-// Copy input to output
-CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
-copy(in, out, ctype, stream);
+void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
-axis = axis < 0 ? axis + in.ndim() : axis;
-size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-size_t n_rows = in_size / in.shape(axis);
+axis = axis < 0 ? axis + out.ndim() : axis;
+size_t in_size = out.size();
+size_t n_rows = in_size / out.shape(axis);
-auto remaining_shape = in.shape();
+auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
-auto remaining_strides = in.strides();
+auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
-auto axis_stride = in.strides()[axis];
-int axis_size = in.shape(axis);
+auto axis_stride = out.strides()[axis];
+int axis_size = out.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
-auto& encoder = cpu::get_command_encoder(stream);
-encoder.set_output_array(out);
-encoder.dispatch([out_ptr = out.data<T>(),
-src_it = std::move(src_it),
-n_rows,
-axis_size,
-axis_stride,
-kth]() mutable {
-for (int i = 0; i < n_rows; i++) {
-T* data_ptr = out_ptr + src_it.loc;
-src_it.step();
+auto out_ptr = out.data<T>();
+for (int i = 0; i < n_rows; i++) {
+T* data_ptr = out_ptr + src_it.loc;
+src_it.step();
-StridedIterator st(data_ptr, axis_stride, 0);
-StridedIterator md(data_ptr, axis_stride, kth);
-StridedIterator ed(data_ptr, axis_stride, axis_size);
+StridedIterator st(data_ptr, axis_stride, 0);
+StridedIterator md(data_ptr, axis_stride, kth);
+StridedIterator ed(data_ptr, axis_stride, axis_size);
-std::nth_element(st, md, ed);
-}
-});
+std::nth_element(st, md, ed);
+}
}
template <typename T, typename IdxT = uint32_t>
-void argpartition(
-const array& in,
-array& out,
-int axis,
-int kth,
-Stream stream) {
-// Allocate output
-out.set_data(allocator::malloc_or_wait(out.nbytes()));
+void argpartition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
@@ -297,42 +253,32 @@ void argpartition(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
-auto& encoder = cpu::get_command_encoder(stream);
-encoder.set_input_array(in);
-encoder.set_input_array(out);
-encoder.dispatch([in_ptr = in.data<T>(),
-out_ptr = out.data<IdxT>(),
-in_it = std::move(in_it),
-out_it = std::move(out_it),
-n_rows,
-axis_size,
-in_stride,
-out_stride,
-kth]() mutable {
-for (int i = 0; i < n_rows; i++) {
-const T* data_ptr = in_ptr + in_it.loc;
-IdxT* idx_ptr = out_ptr + out_it.loc;
-in_it.step();
-out_it.step();
+auto in_ptr = in.data<T>();
+auto out_ptr = out.data<IdxT>();
-StridedIterator st_(idx_ptr, out_stride, 0);
-StridedIterator ed_(idx_ptr, out_stride, axis_size);
+for (int i = 0; i < n_rows; i++) {
+const T* data_ptr = in_ptr + in_it.loc;
+IdxT* idx_ptr = out_ptr + out_it.loc;
+in_it.step();
+out_it.step();
-// Initialize with iota
-std::iota(st_, ed_, IdxT(0));
+StridedIterator st_(idx_ptr, out_stride, 0);
+StridedIterator ed_(idx_ptr, out_stride, axis_size);
-// Sort according to vals
-StridedIterator st(idx_ptr, out_stride, 0);
-StridedIterator md(idx_ptr, out_stride, kth);
-StridedIterator ed(idx_ptr, out_stride, axis_size);
+// Initialize with iota
+std::iota(st_, ed_, IdxT(0));
-std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-auto v1 = data_ptr[a * in_stride];
-auto v2 = data_ptr[b * in_stride];
-return v1 < v2 || (v1 == v2 && a < b);
-});
-}
-});
+// Sort according to vals
+StridedIterator st(idx_ptr, out_stride, 0);
+StridedIterator md(idx_ptr, out_stride, kth);
+StridedIterator ed(idx_ptr, out_stride, axis_size);
+std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
+auto v1 = data_ptr[a * in_stride];
+auto v2 = data_ptr[b * in_stride];
+return v1 < v2 || (v1 == v2 && a < b);
+});
+}
}
} // namespace
@@ -341,144 +287,184 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
-switch (in.dtype()) {
-case bool_:
-return argsort<bool>(in, out, axis_, stream());
-case uint8:
-return argsort<uint8_t>(in, out, axis_, stream());
-case uint16:
-return argsort<uint16_t>(in, out, axis_, stream());
-case uint32:
-return argsort<uint32_t>(in, out, axis_, stream());
-case uint64:
-return argsort<uint64_t>(in, out, axis_, stream());
-case int8:
-return argsort<int8_t>(in, out, axis_, stream());
-case int16:
-return argsort<int16_t>(in, out, axis_, stream());
-case int32:
-return argsort<int32_t>(in, out, axis_, stream());
-case int64:
-return argsort<int64_t>(in, out, axis_, stream());
-case float32:
-return argsort<float>(in, out, axis_, stream());
-case float64:
-return argsort<double>(in, out, axis_, stream());
-case float16:
-return argsort<float16_t>(in, out, axis_, stream());
-case bfloat16:
-return argsort<bfloat16_t>(in, out, axis_, stream());
-case complex64:
-return argsort<complex64_t>(in, out, axis_, stream());
-}
+// Allocate output
+out.set_data(allocator::malloc_or_wait(out.nbytes()));
+auto& encoder = cpu::get_command_encoder(stream());
+encoder.set_input_array(in);
+encoder.set_input_array(out);
+encoder.dispatch([in = array::unsafe_weak_copy(in),
+out = array::unsafe_weak_copy(out),
+axis_ = axis_]() mutable {
+switch (in.dtype()) {
+case bool_:
+return argsort<bool>(in, out, axis_);
+case uint8:
+return argsort<uint8_t>(in, out, axis_);
+case uint16:
+return argsort<uint16_t>(in, out, axis_);
+case uint32:
+return argsort<uint32_t>(in, out, axis_);
+case uint64:
+return argsort<uint64_t>(in, out, axis_);
+case int8:
+return argsort<int8_t>(in, out, axis_);
+case int16:
+return argsort<int16_t>(in, out, axis_);
+case int32:
+return argsort<int32_t>(in, out, axis_);
+case int64:
+return argsort<int64_t>(in, out, axis_);
+case float32:
+return argsort<float>(in, out, axis_);
+case float64:
+return argsort<double>(in, out, axis_);
+case float16:
+return argsort<float16_t>(in, out, axis_);
+case bfloat16:
+return argsort<bfloat16_t>(in, out, axis_);
+case complex64:
+return argsort<complex64_t>(in, out, axis_);
+}
+});
}
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
-switch (in.dtype()) {
-case bool_:
-return sort<bool>(in, out, axis_, stream());
-case uint8:
-return sort<uint8_t>(in, out, axis_, stream());
-case uint16:
-return sort<uint16_t>(in, out, axis_, stream());
-case uint32:
-return sort<uint32_t>(in, out, axis_, stream());
-case uint64:
-return sort<uint64_t>(in, out, axis_, stream());
-case int8:
-return sort<int8_t>(in, out, axis_, stream());
-case int16:
-return sort<int16_t>(in, out, axis_, stream());
-case int32:
-return sort<int32_t>(in, out, axis_, stream());
-case int64:
-return sort<int64_t>(in, out, axis_, stream());
-case float32:
-return sort<float>(in, out, axis_, stream());
-case float64:
-return sort<double>(in, out, axis_, stream());
-case float16:
-return sort<float16_t>(in, out, axis_, stream());
-case bfloat16:
-return sort<bfloat16_t>(in, out, axis_, stream());
-case complex64:
-return sort<complex64_t>(in, out, axis_, stream());
-}
+// Copy input to output
+CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+copy(in, out, ctype, stream());
+auto& encoder = cpu::get_command_encoder(stream());
+encoder.set_output_array(out);
+encoder.dispatch(
+[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
+switch (out.dtype()) {
+case bool_:
+return sort<bool>(out, axis_);
+case uint8:
+return sort<uint8_t>(out, axis_);
+case uint16:
+return sort<uint16_t>(out, axis_);
+case uint32:
+return sort<uint32_t>(out, axis_);
+case uint64:
+return sort<uint64_t>(out, axis_);
+case int8:
+return sort<int8_t>(out, axis_);
+case int16:
+return sort<int16_t>(out, axis_);
+case int32:
+return sort<int32_t>(out, axis_);
+case int64:
+return sort<int64_t>(out, axis_);
+case float32:
+return sort<float>(out, axis_);
+case float64:
+return sort<double>(out, axis_);
+case float16:
+return sort<float16_t>(out, axis_);
+case bfloat16:
+return sort<bfloat16_t>(out, axis_);
+case complex64:
+return sort<complex64_t>(out, axis_);
+}
+});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
-switch (in.dtype()) {
-case bool_:
-return argpartition<bool>(in, out, axis_, kth_, stream());
-case uint8:
-return argpartition<uint8_t>(in, out, axis_, kth_, stream());
-case uint16:
-return argpartition<uint16_t>(in, out, axis_, kth_, stream());
-case uint32:
-return argpartition<uint32_t>(in, out, axis_, kth_, stream());
-case uint64:
-return argpartition<uint64_t>(in, out, axis_, kth_, stream());
-case int8:
-return argpartition<int8_t>(in, out, axis_, kth_, stream());
-case int16:
-return argpartition<int16_t>(in, out, axis_, kth_, stream());
-case int32:
-return argpartition<int32_t>(in, out, axis_, kth_, stream());
-case int64:
-return argpartition<int64_t>(in, out, axis_, kth_, stream());
-case float32:
-return argpartition<float>(in, out, axis_, kth_, stream());
-case float64:
-return argpartition<double>(in, out, axis_, kth_, stream());
-case float16:
-return argpartition<float16_t>(in, out, axis_, kth_, stream());
-case bfloat16:
-return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
-case complex64:
-return argpartition<complex64_t>(in, out, axis_, kth_, stream());
-}
+// Allocate output
+out.set_data(allocator::malloc_or_wait(out.nbytes()));
+auto& encoder = cpu::get_command_encoder(stream());
+encoder.set_input_array(in);
+encoder.set_input_array(out);
+encoder.dispatch([in = array::unsafe_weak_copy(in),
+out = array::unsafe_weak_copy(out),
+axis_ = axis_,
+kth_ = kth_]() mutable {
+switch (in.dtype()) {
+case bool_:
+return argpartition<bool>(in, out, axis_, kth_);
+case uint8:
+return argpartition<uint8_t>(in, out, axis_, kth_);
+case uint16:
+return argpartition<uint16_t>(in, out, axis_, kth_);
+case uint32:
+return argpartition<uint32_t>(in, out, axis_, kth_);
+case uint64:
+return argpartition<uint64_t>(in, out, axis_, kth_);
+case int8:
+return argpartition<int8_t>(in, out, axis_, kth_);
+case int16:
+return argpartition<int16_t>(in, out, axis_, kth_);
+case int32:
+return argpartition<int32_t>(in, out, axis_, kth_);
+case int64:
+return argpartition<int64_t>(in, out, axis_, kth_);
+case float32:
+return argpartition<float>(in, out, axis_, kth_);
+case float64:
+return argpartition<double>(in, out, axis_, kth_);
+case float16:
+return argpartition<float16_t>(in, out, axis_, kth_);
+case bfloat16:
+return argpartition<bfloat16_t>(in, out, axis_, kth_);
+case complex64:
+return argpartition<complex64_t>(in, out, axis_, kth_);
+}
+});
}
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
-switch (in.dtype()) {
-case bool_:
-return partition<bool>(in, out, axis_, kth_, stream());
-case uint8:
-return partition<uint8_t>(in, out, axis_, kth_, stream());
-case uint16:
-return partition<uint16_t>(in, out, axis_, kth_, stream());
-case uint32:
-return partition<uint32_t>(in, out, axis_, kth_, stream());
-case uint64:
-return partition<uint64_t>(in, out, axis_, kth_, stream());
-case int8:
-return partition<int8_t>(in, out, axis_, kth_, stream());
-case int16:
-return partition<int16_t>(in, out, axis_, kth_, stream());
-case int32:
-return partition<int32_t>(in, out, axis_, kth_, stream());
-case int64:
-return partition<int64_t>(in, out, axis_, kth_, stream());
-case float32:
-return partition<float>(in, out, axis_, kth_, stream());
-case float64:
-return partition<double>(in, out, axis_, kth_, stream());
-case float16:
-return partition<float16_t>(in, out, axis_, kth_, stream());
-case bfloat16:
-return partition<bfloat16_t>(in, out, axis_, kth_, stream());
-case complex64:
-return partition<complex64_t>(in, out, axis_, kth_, stream());
-}
+// Copy input to output
+CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
+copy(in, out, ctype, stream());
+auto& encoder = cpu::get_command_encoder(stream());
+encoder.set_output_array(out);
+encoder.dispatch([out = array::unsafe_weak_copy(out),
+axis_ = axis_,
+kth_ = kth_]() mutable {
+switch (out.dtype()) {
+case bool_:
+return partition<bool>(out, axis_, kth_);
+case uint8:
+return partition<uint8_t>(out, axis_, kth_);
+case uint16:
+return partition<uint16_t>(out, axis_, kth_);
+case uint32:
+return partition<uint32_t>(out, axis_, kth_);
+case uint64:
+return partition<uint64_t>(out, axis_, kth_);
+case int8:
+return partition<int8_t>(out, axis_, kth_);
+case int16:
+return partition<int16_t>(out, axis_, kth_);
+case int32:
+return partition<int32_t>(out, axis_, kth_);
+case int64:
+return partition<int64_t>(out, axis_, kth_);
+case float32:
+return partition<float>(out, axis_, kth_);
+case float64:
+return partition<double>(out, axis_, kth_);
+case float16:
+return partition<float16_t>(out, axis_, kth_);
+case bfloat16:
+return partition<bfloat16_t>(out, axis_, kth_);
+case complex64:
+return partition<complex64_t>(out, axis_, kth_);
+}
+});
}
} // namespace mlx::core