reduce binary size (#1952)

Awni Hannun
2025-03-11 06:30:44 -07:00
committed by GitHub
parent 117e1355a2
commit 736a340478
16 changed files with 2145 additions and 2386 deletions
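
The change below shrinks the CPU backend by collapsing the dispatch structure of the reduction kernels: instead of every templated reduction_op<InT, U, Op> instantiation capturing its own encoder.dispatch closures, Reduce::eval_cpu now allocates the output, registers the arrays with the command encoder, and enqueues a single lambda that performs the whole dtype/op switch inside it. The sketch below is a standalone illustration of that pattern, not MLX code; the names (the task-queue stand-in, DType, typed_sum, enqueue_sum) are invented for the example.

// Standalone sketch of the pattern applied in this commit (not MLX code):
// rather than enqueueing one typed closure per template instantiation,
// enqueue a single closure and do the runtime type switch inside it.
#include <cstdint>
#include <functional>
#include <queue>
#include <vector>

// Hypothetical stand-in for the CPU command encoder's task queue.
std::queue<std::function<void()>> tasks;

enum class DType { Int8, Float32 };

template <typename T>
void typed_sum(const void* in, double* out, size_t n) {
  const T* p = static_cast<const T*>(in);
  double acc = 0;
  for (size_t i = 0; i < n; ++i) acc += static_cast<double>(p[i]);
  *out = acc;
}

// Before this commit's pattern, each dtype/op combination would enqueue its
// own typed lambda, contributing a distinct closure type to the binary.
// Here a single closure is enqueued and the dtype switch runs inside it, so
// only the typed kernels themselves are instantiated.
void enqueue_sum(DType dt, const void* in, double* out, size_t n) {
  tasks.push([dt, in, out, n]() {
    switch (dt) {
      case DType::Int8:    typed_sum<int8_t>(in, out, n); break;
      case DType::Float32: typed_sum<float>(in, out, n);  break;
    }
  });
}

int main() {
  std::vector<float> x(8, 1.5f);
  double result = 0;
  enqueue_sum(DType::Float32, x.data(), &result, x.size());
  tasks.front()();  // a worker thread would normally run the queued task
  return result == 12.0 ? 0 : 1;
}
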


@@ -140,34 +140,23 @@ void reduction_op(
const array& x,
array& out,
const std::vector<int>& axes,
U init,
Stream stream) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
U init) {
ReductionPlan plan = get_reduction_plan(x, axes);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_output_array(out);
auto in_ptr = x.data<T>();
auto out_ptr = out.data<U>();
if (plan.type == ContiguousAllReduce) {
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
});
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
return;
}
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
encoder.dispatch(
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
});
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
*out_ptr = init;
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
}
return;
}
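
In both branches above the caller seeds *out_ptr with init and then lets contiguous_reduce fold a contiguous run of inputs into it. contiguous_reduce itself is not part of this diff; under that assumption, a minimal scalar version could look like the following (the real helper may be vectorized), with its call sites shown in the hunk above.

// Assumed shape of contiguous_reduce (the real helper lives elsewhere in
// the CPU backend): fold `size` contiguous inputs into the accumulator at
// `out`, which the caller has already set to the init value.
#include <cstddef>

template <typename T, typename U, typename Op>
void contiguous_reduce_sketch(const T* in, U* out, size_t size, Op op, U /*init*/) {
  U acc = *out;
  for (size_t i = 0; i < size; ++i) {
    acc = op(acc, static_cast<U>(in[i]));
  }
  *out = acc;
}
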
@@ -178,40 +167,29 @@ void reduction_op(
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should yield an extra performance boost.
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
} else {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
Op{},
init);
},
plan.shape,
plan.strides);
}
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
}
});
} else {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
*out_ptr = init;
nd_loop(
[&](int extra_offset) {
contiguous_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
Op{},
init);
},
plan.shape,
plan.strides);
}
}
return;
}
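
The branch above maps each linear output index to a storage offset with elem_to_loc before reducing. That helper is defined elsewhere in MLX; assuming the usual row-major unravel-and-dot-with-strides behaviour, a minimal equivalent is sketched below, with the loops above as its usage.

// Assumed behaviour of elem_to_loc: unravel a row-major linear index into
// coordinates over `shape` and dot them with `strides` to obtain the
// element's storage offset.
#include <cstdint>
#include <vector>

inline int64_t elem_to_loc_sketch(
    int64_t idx,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (idx % shape[i]) * strides[i];
    idx /= shape[i];
  }
  return loc;
}
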
@@ -220,20 +198,12 @@ void reduction_op(
size_t reduction_stride = plan.strides.back();
plan.shape.pop_back();
plan.strides.pop_back();
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size()]() mutable {
for (int i = 0; i < size; i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
});
for (int i = 0; i < out.size(); i += reduction_stride) {
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
in_ptr += reduction_stride * reduction_size;
out_ptr += reduction_stride;
}
return;
}
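
For ContiguousStridedReduce the innermost axis pair is peeled off into reduction_size and reduction_stride: each iteration fills reduction_stride outputs with init, reduces a block of reduction_size rows column-wise, and advances both pointers past the block. strided_reduce is not shown in this diff; a plausible scalar sketch of what the loop above relies on:

// Assumed behaviour of strided_reduce: the input is laid out as
// `reduction_size` rows of `reduction_stride` values; each of the
// `reduction_stride` outputs accumulates one column. The caller has already
// filled the outputs with the init value.
#include <cstddef>

template <typename T, typename U, typename Op>
void strided_reduce_sketch(
    const T* in, U* out, size_t reduction_size, size_t reduction_stride, Op op) {
  for (size_t r = 0; r < reduction_size; ++r) {
    for (size_t c = 0; c < reduction_stride; ++c) {
      out[c] = op(out[c], static_cast<U>(in[r * reduction_stride + c]));
    }
  }
}
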
@@ -245,67 +215,49 @@ void reduction_op(
plan.strides.pop_back();
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
reduction_size,
reduction_stride,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
if (plan.shape.size() == 0) {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
}
} else {
for (int i = 0; i < size; i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
Op{});
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
strided_reduce(
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
out_ptr += reduction_stride;
}
});
} else {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
std::fill_n(out_ptr, reduction_stride, init);
nd_loop(
[&](int extra_offset) {
strided_reduce(
in_ptr + offset + extra_offset,
out_ptr,
reduction_size,
reduction_stride,
Op{});
},
plan.shape,
plan.strides);
out_ptr += reduction_stride;
}
}
return;
}
if (plan.type == GeneralReduce) {
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
encoder.dispatch([in_ptr,
out_ptr,
init,
size = out.size(),
plan = std::move(plan),
shape = std::move(shape),
strides = std::move(strides)]() mutable {
for (int i = 0; i < size; i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
});
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;
nd_loop(
[&](int extra_offset) {
val = Op{}(val, *(in_ptr + offset + extra_offset));
},
plan.shape,
plan.strides);
*out_ptr = val;
}
}
}
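
The GeneralReduce fallback above accumulates through nd_loop, which is assumed to visit every coordinate of plan.shape and hand the functor the offset obtained from plan.strides. A recursive sketch with that behaviour follows (the real helper lives in the shared backend utilities and may be implemented iteratively); the nd_loop calls in the hunks above show how it is used.

// Assumed behaviour of nd_loop: visit every coordinate of `shape` and call
// the functor with the offset obtained by dotting that coordinate with
// `strides`.
#include <cstdint>
#include <functional>
#include <vector>

inline void nd_loop_sketch(
    const std::function<void(int)>& fn,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides,
    int axis = 0,
    int64_t offset = 0) {
  if (axis == static_cast<int>(shape.size())) {
    fn(static_cast<int>(offset));
    return;
  }
  for (int i = 0; i < shape[axis]; ++i) {
    nd_loop_sketch(fn, shape, strides, axis + 1, offset + i * strides[axis]);
  }
}
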
@@ -434,12 +386,11 @@ void reduce_dispatch_and_or(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::And) {
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
} else {
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
}
}
@@ -448,19 +399,18 @@ void reduce_dispatch_sum_prod(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::Sum) {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
} else {
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
}
} else {
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
} else {
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
}
}
}
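
reduce_dispatch_sum_prod widens integer inputs of at most 32 bits to an int32_t accumulator, which keeps long sums and products of narrow types from overflowing the input type. A small standalone demonstration of the difference (not MLX code):

// Why narrow integer inputs accumulate into int32_t: the sum of many int8
// values easily exceeds the int8 range, while int32 holds it comfortably.
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int8_t> v(200, 100);  // true sum: 200 * 100 = 20000
  int32_t wide = std::accumulate(v.begin(), v.end(), int32_t{0});
  int8_t narrow = 0;
  for (auto x : v) narrow = static_cast<int8_t>(narrow + x);  // wraps around
  std::printf("int32 accumulator: %d, int8 accumulator: %d\n", wide, narrow);
  return 0;
}
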
@@ -470,162 +420,144 @@ void reduce_dispatch_min_max(
const array& in,
array& out,
Reduce::ReduceType rtype,
const std::vector<int>& axes,
Stream stream) {
const std::vector<int>& axes) {
if (rtype == Reduce::Max) {
auto init = Limits<InT>::min;
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
} else {
auto init = Limits<InT>::max;
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
}
}
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.dispatch([in = array::unsafe_weak_copy(in),
out = array::unsafe_weak_copy(out),
reduce_type_ = reduce_type_,
axes_ = axes_]() mutable {
switch (reduce_type_) {
case Reduce::And:
case Reduce::Or: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
case float16:
case bfloat16:
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
case int32:
case float32:
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_sum_prod<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_sum_prod<double>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
case Reduce::Sum:
case Reduce::Prod: {
switch (in.dtype()) {
case bool_:
case uint8:
case int8:
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
break;
case int16:
case uint16:
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
break;
case int32:
case uint32:
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
case uint64:
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(
in, out, reduce_type_, axes_, stream());
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(
in, out, reduce_type_, axes_, stream());
break;
case int8:
reduce_dispatch_min_max<uint8_t>(
in, out, reduce_type_, axes_, stream());
break;
case int16:
reduce_dispatch_min_max<uint16_t>(
in, out, reduce_type_, axes_, stream());
break;
case int32:
reduce_dispatch_min_max<int32_t>(
in, out, reduce_type_, axes_, stream());
break;
case int64:
reduce_dispatch_min_max<int64_t>(
in, out, reduce_type_, axes_, stream());
break;
case float16:
reduce_dispatch_min_max<float16_t>(
in, out, reduce_type_, axes_, stream());
break;
case float32:
reduce_dispatch_min_max<float>(
in, out, reduce_type_, axes_, stream());
break;
case float64:
reduce_dispatch_min_max<double>(
in, out, reduce_type_, axes_, stream());
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(
in, out, reduce_type_, axes_, stream());
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(
in, out, reduce_type_, axes_, stream());
break;
case Reduce::Max:
case Reduce::Min: {
switch (in.dtype()) {
case bool_:
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
break;
case uint8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case uint16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case uint32:
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
break;
case uint64:
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
break;
case int8:
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
break;
case int16:
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
break;
case int32:
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
break;
case int64:
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
break;
case float16:
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
break;
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
break;
}
break;
}
break;
}
}
});
}
} // namespace mlx::core