redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
Author: Awni Hannun
Date: 2025-03-06 19:23:38 -08:00
Committed by: GitHub
Parent: 5245f12a46
Commit: c4230747a1

103 changed files with 5013 additions and 3873 deletions
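
The representative diff below (one file of the 103) shows the core of the redesign: CPU primitives no longer run their loops inline in `eval_cpu`. Instead they fetch the stream's command encoder, register input and output arrays, and enqueue the actual work as a lambda via `encoder.dispatch(...)`, so the calling thread returns immediately and the loops run asynchronously. To make the pattern concrete, here is a minimal toy encoder; the queue-and-worker internals below are illustrative assumptions, not MLX's actual implementation:

```cpp
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// Hypothetical single-stream command encoder: dispatch() enqueues a task and
// returns immediately; a worker thread drains the queue in FIFO order.
class ToyCommandEncoder {
 public:
  ToyCommandEncoder() : worker_([this] { run(); }) {}

  ~ToyCommandEncoder() {
    {
      std::lock_guard<std::mutex> lk(m_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join(); // drain remaining tasks, then stop
  }

  // Tasks must capture everything by value: the caller's stack frame is
  // usually gone by the time the task runs.
  void dispatch(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lk(m_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void run() {
    std::unique_lock<std::mutex> lk(m_);
    while (!done_ || !tasks_.empty()) {
      cv_.wait(lk, [this] { return done_ || !tasks_.empty(); });
      while (!tasks_.empty()) {
        auto task = std::move(tasks_.front());
        tasks_.pop();
        lk.unlock();
        task(); // run the work outside the lock
        lk.lock();
      }
    }
  }

  std::queue<std::function<void()>> tasks_;
  std::mutex m_;
  std::condition_variable cv_;
  bool done_ = false;
  std::thread worker_; // declared last so the other members outlive it
};

int main() {
  int result = 0;
  {
    ToyCommandEncoder enc;
    enc.dispatch([&result] { result = 42; }); // returns before the task runs
  } // the destructor joins the worker, so the write is visible here
  std::printf("result: %d\n", result);
}
```

Because `dispatch` only queues the task, the main thread can keep encoding further ops while earlier ones execute, which is the "more async CPU" the commit message refers to.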


@@ -5,6 +5,7 @@
 #include <limits>
 #include "mlx/backend/common/reduce.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
@@ -140,25 +141,33 @@ void reduction_op(
     array& out,
     const std::vector<int>& axes,
     U init,
-    Op op) {
+    Stream stream) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   ReductionPlan plan = get_reduction_plan(x, axes);
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+  auto in_ptr = x.data<T>();
+  auto out_ptr = out.data<U>();
   if (plan.type == ContiguousAllReduce) {
-    U* out_ptr = out.data<U>();
-    *out_ptr = init;
-    contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
+    encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
+      *out_ptr = init;
+      contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
+    });
     return;
   }
   if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     int reduction_size = plan.shape[0];
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
-      *out_ptr = init;
-      contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
-    }
+    encoder.dispatch(
+        [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
+          for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
+            *out_ptr = init;
+            contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+          }
+        });
     return;
   }
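
Two details in the hunk above carry the design. First, the lambda captures `in_ptr`, `out_ptr`, `init`, and `size` by value, because `reduction_op` returns before the task runs. Second, the reduce operation moves from a runtime argument (`Op op`) to a template parameter, and the task builds a fresh stateless `Op{}` when it executes, so no operation object needs capturing. A minimal sketch of the same shape; `run_deferred` and the other names here are hypothetical stand-ins for `encoder.dispatch`:

```cpp
#include <cstddef>
#include <functional>

// Stand-in for encoder.dispatch(): runs the task immediately here, but in the
// real design it would be queued for a worker thread.
inline void run_deferred(std::function<void()> task) {
  task();
}

// Everything the task touches is captured by value; the caller's locals may
// no longer exist by the time a genuinely deferred task runs.
template <typename T, typename U, typename Op>
void deferred_all_reduce(const T* in, U* out, size_t size, U init) {
  run_deferred([in, out, size, init]() {
    U acc = init;
    for (size_t i = 0; i < size; ++i) {
      acc = Op{}(acc, in[i]); // Op is stateless, so construct it on demand
    }
    *out = acc;
  });
}

// A stateless functor in the shape of the diff's SumReduce.
struct ToySum {
  template <typename U, typename T>
  U operator()(U acc, T v) const {
    return acc + static_cast<U>(v);
  }
};

int main() {
  int xs[5] = {1, 2, 3, 4, 5};
  long total = 0;
  deferred_all_reduce<int, long, ToySum>(xs, &total, 5, 0L);
  return total == 15 ? 0 : 1;
}
```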
@@ -166,34 +175,43 @@ void reduction_op(
     int reduction_size = plan.shape.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     // Unrolling the following loop (and implementing it in order for
     // ContiguousReduce) should hold extra performance boost.
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
-      }
-    } else {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        nd_loop(
-            [&](int extra_offset) {
-              contiguous_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  op,
-                  init);
-            },
-            plan.shape,
-            plan.strides);
-      }
-    }
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          contiguous_reduce(
+              in_ptr + offset, out_ptr, reduction_size, Op{}, init);
+        }
+      } else {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          nd_loop(
+              [&](int extra_offset) {
+                contiguous_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    Op{},
+                    init);
+              },
+              plan.shape,
+              plan.strides);
+        }
+      }
+    });
     return;
   }
@@ -202,14 +220,20 @@ void reduction_op(
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i += reduction_stride) {
-      std::fill_n(out_ptr, reduction_stride, init);
-      strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
-      x_ptr += reduction_stride * reduction_size;
-      out_ptr += reduction_stride;
-    }
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size()]() mutable {
+      for (int i = 0; i < size; i += reduction_stride) {
+        std::fill_n(out_ptr, reduction_stride, init);
+        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+        in_ptr += reduction_stride * reduction_size;
+        out_ptr += reduction_stride;
+      }
+    });
     return;
   }
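
Note the `mutable` qualifier on these lambdas: `in_ptr` and `out_ptr` are captured by value and then advanced as the task walks the output, which a non-`mutable` lambda would reject at compile time. A standalone illustration:

```cpp
#include <cstdio>

int main() {
  int data[6] = {1, 2, 3, 4, 5, 6};
  const int* p = data;
  // p is captured by value; `mutable` lets the body modify the lambda's copy.
  auto task = [p]() mutable {
    for (int row = 0; row < 2; ++row) {
      int sum = 0;
      for (int col = 0; col < 3; ++col) {
        sum += *p++; // advances the copy inside the lambda, not the outer p
      }
      std::printf("row %d sum: %d\n", row, sum); // prints 6, then 15
    }
  };
  task();
  // The outer p still points at data[0].
}
```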
@@ -219,53 +243,69 @@ void reduction_op(
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        strided_reduce(
-            x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
-        out_ptr += reduction_stride;
-      }
-    } else {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        nd_loop(
-            [&](int extra_offset) {
-              strided_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  reduction_stride,
-                  op);
-            },
-            plan.shape,
-            plan.strides);
-        out_ptr += reduction_stride;
-      }
-    }
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          strided_reduce(
+              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+          out_ptr += reduction_stride;
+        }
+      } else {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          nd_loop(
+              [&](int extra_offset) {
+                strided_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    reduction_stride,
+                    Op{});
+              },
+              plan.shape,
+              plan.strides);
+          out_ptr += reduction_stride;
+        }
+      }
+    });
     return;
   }
   if (plan.type == GeneralReduce) {
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    for (int i = 0; i < out.size(); i++, out_ptr++) {
-      int offset = elem_to_loc(i, shape, strides);
-      U val = init;
-      nd_loop(
-          [&](int extra_offset) {
-            val = op(val, *(x_ptr + offset + extra_offset));
-          },
-          plan.shape,
-          plan.strides);
-      *out_ptr = val;
-    }
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      for (int i = 0; i < size; i++, out_ptr++) {
+        int offset = elem_to_loc(i, shape, strides);
+        U val = init;
+        nd_loop(
+            [&](int extra_offset) {
+              val = Op{}(val, *(in_ptr + offset + extra_offset));
+            },
+            plan.shape,
+            plan.strides);
+        *out_ptr = val;
+      }
+    });
   }
 }
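
Throughout these loops, `elem_to_loc(i, shape, strides)` maps a flat row-major index over the non-reduced axes to a strided offset into the input. A hedged reimplementation to show the idea; this is illustrative only, and MLX's real helper may differ in signature and details:

```cpp
#include <cstddef>
#include <vector>

// Walk dimensions from innermost to outermost, peeling off the row-major
// coordinate for each dimension and accumulating coordinate * stride.
size_t toy_elem_to_loc(
    size_t i,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
    loc += (i % shape[d]) * strides[d];
    i /= shape[d];
  }
  return loc;
}

int main() {
  // shape {2, 3} stored with column-major strides {1, 2}: flat index 4 is
  // (row 1, col 1), which lives at offset 1 * 1 + 1 * 2 = 3.
  return toy_elem_to_loc(4, {2, 3}, {1, 2}) == 3 ? 0 : 1;
}
```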
@@ -394,11 +434,12 @@ void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::And) {
-    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+    reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
   } else {
-    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+    reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
   }
 }
@@ -407,18 +448,19 @@ void reduce_dispatch_sum_prod(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Sum) {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
     } else {
-      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
     }
   } else {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
     } else {
-      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
     }
   }
 }
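
The `if constexpr` branch kept in this hunk widens narrow integer accumulation: integer inputs of at most 32 bits reduce into an `int32_t` accumulator, so summing many `int8_t` values, for example, does not wrap at that type's range. A small demonstration of why the accumulator type matters:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int8_t xs[100];
  for (auto& x : xs) {
    x = 100; // true sum is 10000, far outside int8_t's range
  }

  int8_t narrow = 0;
  int32_t wide = 0;
  for (auto x : xs) {
    narrow = static_cast<int8_t>(narrow + x); // wraps modulo 256 each step
    wide += x;
  }
  std::printf("int8 accumulator: %d\n", narrow); // 16 (10000 mod 256)
  std::printf("int32 accumulator: %d\n", wide); // 10000
}
```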
@@ -428,13 +470,14 @@ void reduce_dispatch_min_max(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Max) {
     auto init = Limits<InT>::min;
-    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+    reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
   } else {
     auto init = Limits<InT>::max;
-    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+    reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
   }
 }
@@ -448,24 +491,28 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
         case bool_:
         case uint8:
         case int8:
-          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_and_or<int8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int16:
         case uint16:
         case float16:
         case bfloat16:
-          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_and_or<int16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint32:
         case int32:
         case float32:
-          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_and_or<int32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint64:
         case int64:
         case float64:
         case complex64:
-          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_and_or<int64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
       }
       break;
@@ -476,34 +523,43 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
         case bool_:
         case uint8:
         case int8:
-          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<int8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int16:
         case uint16:
-          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<int16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int32:
         case uint32:
-          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<int32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int64:
         case uint64:
-          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<int64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float16:
-          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<float16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case bfloat16:
-          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<bfloat16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float32:
-          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<float>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float64:
-          reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<double>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case complex64:
-          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_sum_prod<complex64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
       }
       break;
@@ -512,46 +568,59 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
     case Reduce::Min: {
       switch (in.dtype()) {
         case bool_:
-          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
           break;
         case uint8:
-          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint16:
-          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint32:
-          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint64:
-          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int8:
-          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int16:
-          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int32:
-          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<int32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int64:
-          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<int64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float16:
-          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<float16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float32:
-          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<float>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float64:
-          reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<double>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case bfloat16:
-          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<bfloat16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case complex64:
-          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<complex64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
       }
       break;