mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
reduce binary size (#1952)
This commit is contained in:
@@ -140,34 +140,23 @@ void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init,
|
||||
Stream stream) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
U init) {
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto in_ptr = x.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
|
||||
});
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
encoder.dispatch(
|
||||
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
});
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -178,40 +167,29 @@ void reduction_op(
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
Op{},
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
Op{},
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -220,20 +198,12 @@ void reduction_op(
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
});
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -245,67 +215,49 @@ void reduction_op(
|
||||
plan.strides.pop_back();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
Op{});
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
Op{});
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = Op{}(val, *(in_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
});
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = Op{}(val, *(in_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
|
||||
} else {
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
|
||||
} else {
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
|
||||
} else {
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -470,162 +420,144 @@ void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
const std::vector<int>& axes) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
|
||||
}
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axes_ = axes_]() mutable {
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user