redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch
* load + more async CPU
* use command encoder API and move more ops to use it
* make fence back-end generic + CPU only fence
* faster build
* fix async eval
* fixes + handle temporaries
* fix / improve cpu conv
* remove unused status, fix siblings
* fix extensions
* fix
* fix no cpu build
* format
* comments
* fix perf regression, remove unnecessary abort
* fix events, task limit cpu
* fix waiting
* fix donation / temporaries in normalization
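At its core, this change stops the CPU back end from executing reductions inline on the calling thread: each code path below now records its work as a lambda on a per-stream command encoder. Before the diff, a minimal self-contained sketch of that dispatch pattern. The class name and `dispatch` method mirror what the diff uses; the body is an illustrative assumption, not MLX's actual encoder.

// Minimal sketch of a per-stream CPU command encoder (illustrative only):
// a FIFO of tasks drained by one worker thread, so ops can be recorded
// without blocking the thread that builds the graph.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

class CommandEncoder {
 public:
  CommandEncoder() : worker_([this] { run(); }) {}

  ~CommandEncoder() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_one();
    worker_.join(); // Drains any remaining tasks before destruction.
  }

  // Record a task and return immediately; the worker runs it later,
  // in the order it was dispatched.
  void dispatch(std::function<void()> task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();
  }

 private:
  void run() {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return done_ || !tasks_.empty(); });
        if (tasks_.empty()) {
          return; // done_ was set and the queue is drained.
        }
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task(); // Executed outside the lock, one task at a time.
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::thread worker_; // Declared last so it starts after the state above.
};

Under a model like this, an op records its loop and returns immediately; within a stream, tasks still run one at a time in FIFO order, which preserves the semantics of eager execution.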
@@ -5,6 +5,7 @@
 #include <limits>
 
 #include "mlx/backend/common/reduce.h"
+#include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
 
@@ -140,25 +141,33 @@ void reduction_op(
     array& out,
     const std::vector<int>& axes,
     U init,
-    Op op) {
+    Stream stream) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
   ReductionPlan plan = get_reduction_plan(x, axes);
+
+  auto& encoder = cpu::get_command_encoder(stream);
+  encoder.set_input_array(x);
+  encoder.set_output_array(out);
+
+  auto in_ptr = x.data<T>();
+  auto out_ptr = out.data<U>();
   if (plan.type == ContiguousAllReduce) {
-    U* out_ptr = out.data<U>();
-    *out_ptr = init;
-    contiguous_reduce(x.data<T>(), out_ptr, x.size(), op, init);
+    encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
+      *out_ptr = init;
+      contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
+    });
     return;
   }
 
   if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
     int reduction_size = plan.shape[0];
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) {
-      *out_ptr = init;
-      contiguous_reduce(x_ptr, out_ptr, reduction_size, op, init);
-    }
+    encoder.dispatch(
+        [in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
+          for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
+            *out_ptr = init;
+            contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
+          }
+        });
     return;
   }
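A detail worth noting in the hunk above: the functor no longer travels as the value parameter `Op op`; it is now a template parameter of `reduction_op`, and each dispatched task constructs a fresh `Op{}` at the point of use. The reduce functors are stateless, so this is free, and it keeps the capture list down to plain pointers and sizes. A hedged sketch of the pattern, reusing the toy `CommandEncoder` from above (`SumOp` and `dispatch_reduce` are illustrative names, not MLX's):

#include <cstddef>

// A stateless reduce functor: constructing one with Op{} costs nothing.
struct SumOp {
  template <typename U, typename T>
  U operator()(U acc, T x) const {
    return acc + static_cast<U>(x);
  }
};

// With Op as a template parameter, the closure captures only values; the
// functor is rebuilt inside the task when it eventually runs.
template <typename T, typename U, typename Op>
void dispatch_reduce(CommandEncoder& enc, const T* in, U* out, size_t n, U init) {
  enc.dispatch([in, out, n, init]() {
    U acc = init;
    for (size_t i = 0; i < n; ++i) {
      acc = Op{}(acc, in[i]);
    }
    *out = acc;
  });
}

// Usage: dispatch_reduce<int, long, SumOp>(enc, data, &result, n, 0L);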
@@ -166,34 +175,43 @@ void reduction_op(
     int reduction_size = plan.shape.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     // Unrolling the following loop (and implementing it in order for
     // ContiguousReduce) should hold extra performance boost.
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        contiguous_reduce(x_ptr + offset, out_ptr, reduction_size, op, init);
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          contiguous_reduce(
+              in_ptr + offset, out_ptr, reduction_size, Op{}, init);
+        }
+      } else {
+        for (int i = 0; i < size; i++, out_ptr++) {
+          int offset = elem_to_loc(i, shape, strides);
+          *out_ptr = init;
+          nd_loop(
+              [&](int extra_offset) {
+                contiguous_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    Op{},
+                    init);
+              },
+              plan.shape,
+              plan.strides);
+        }
       }
-    } else {
-      for (int i = 0; i < out.size(); i++, out_ptr++) {
-        int offset = elem_to_loc(i, shape, strides);
-        *out_ptr = init;
-        nd_loop(
-            [&](int extra_offset) {
-              contiguous_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  op,
-                  init);
-            },
-            plan.shape,
-            plan.strides);
-      }
-    }
+    });
     return;
   }
@@ -202,14 +220,20 @@
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
-    for (int i = 0; i < out.size(); i += reduction_stride) {
-      std::fill_n(out_ptr, reduction_stride, init);
-      strided_reduce(x_ptr, out_ptr, reduction_size, reduction_stride, op);
-      x_ptr += reduction_stride * reduction_size;
-      out_ptr += reduction_stride;
-    }
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size()]() mutable {
+      for (int i = 0; i < size; i += reduction_stride) {
+        std::fill_n(out_ptr, reduction_stride, init);
+        strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
+        in_ptr += reduction_stride * reduction_size;
+        out_ptr += reduction_stride;
+      }
+    });
     return;
   }
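Note the `mutable` on the dispatched lambdas above: the loops advance the captured copies of `in_ptr` and `out_ptr`, and by-value captures are immutable inside a lambda body by default. A tiny self-contained illustration:

#include <array>

// Why the dispatched lambdas are `mutable`: by-value captures are const in
// the lambda body by default, and these loops advance their captured pointers.
int mutable_capture_demo() {
  std::array<int, 4> buf{1, 2, 3, 4};
  int* p = buf.data();
  auto task = [p]() mutable {
    p += 2; // Moves the closure's own copy of p; the outer p is untouched.
    return *p;
  };
  return task(); // Returns 3; the outer p still points at buf[0].
}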
@@ -219,53 +243,69 @@
     size_t reduction_stride = plan.strides.back();
     plan.shape.pop_back();
     plan.strides.pop_back();
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    if (plan.shape.size() == 0) {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        strided_reduce(
-            x_ptr + offset, out_ptr, reduction_size, reduction_stride, op);
-        out_ptr += reduction_stride;
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      reduction_size,
+                      reduction_stride,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      if (plan.shape.size() == 0) {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          strided_reduce(
+              in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
+          out_ptr += reduction_stride;
+        }
+      } else {
+        for (int i = 0; i < size; i += reduction_stride) {
+          int offset = elem_to_loc(i, shape, strides);
+          std::fill_n(out_ptr, reduction_stride, init);
+          nd_loop(
+              [&](int extra_offset) {
+                strided_reduce(
+                    in_ptr + offset + extra_offset,
+                    out_ptr,
+                    reduction_size,
+                    reduction_stride,
+                    Op{});
+              },
+              plan.shape,
+              plan.strides);
+          out_ptr += reduction_stride;
+        }
       }
-    } else {
-      for (int i = 0; i < out.size(); i += reduction_stride) {
-        int offset = elem_to_loc(i, shape, strides);
-        std::fill_n(out_ptr, reduction_stride, init);
-        nd_loop(
-            [&](int extra_offset) {
-              strided_reduce(
-                  x_ptr + offset + extra_offset,
-                  out_ptr,
-                  reduction_size,
-                  reduction_stride,
-                  op);
-            },
-            plan.shape,
-            plan.strides);
-        out_ptr += reduction_stride;
-      }
-    }
+    });
     return;
   }
 
   if (plan.type == GeneralReduce) {
-    const T* x_ptr = x.data<T>();
-    U* out_ptr = out.data<U>();
     auto [shape, strides] = shapes_without_reduction_axes(x, axes);
-    for (int i = 0; i < out.size(); i++, out_ptr++) {
-      int offset = elem_to_loc(i, shape, strides);
-      U val = init;
-      nd_loop(
-          [&](int extra_offset) {
-            val = op(val, *(x_ptr + offset + extra_offset));
-          },
-          plan.shape,
-          plan.strides);
-      *out_ptr = val;
-    }
+
+    encoder.dispatch([in_ptr,
+                      out_ptr,
+                      init,
+                      size = out.size(),
+                      plan = std::move(plan),
+                      shape = std::move(shape),
+                      strides = std::move(strides)]() mutable {
+      for (int i = 0; i < size; i++, out_ptr++) {
+        int offset = elem_to_loc(i, shape, strides);
+        U val = init;
+        nd_loop(
+            [&](int extra_offset) {
+              val = Op{}(val, *(in_ptr + offset + extra_offset));
+            },
+            plan.shape,
+            plan.strides);
+        *out_ptr = val;
+      }
+    });
   }
 }
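The other capture detail in these hunks: `plan`, `shape`, and `strides` are moved into the closures with init-captures. Since the task may run long after `reduction_op` has returned, the closure has to own any heap-backed data it touches; capturing those vectors by reference would dangle. A small sketch of the pattern (`compute_shape`, `use`, and `enqueue` are stand-ins, and `CommandEncoder` is the toy from above):

#include <utility>
#include <vector>

std::vector<int> compute_shape() { return {4, 8}; }
void use(const std::vector<int>&) {}

void enqueue(CommandEncoder& encoder) {
  std::vector<int> shape = compute_shape();
  // Init-capture transfers ownership; the vector outlives this stack frame.
  encoder.dispatch([shape = std::move(shape)]() { use(shape); });
  // A [&shape] capture here would dangle once enqueue() returns.
}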
@@ -394,11 +434,12 @@ void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::And) {
-    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+    reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
   } else {
-    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+    reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
   }
 }
@@ -407,18 +448,19 @@ void reduce_dispatch_sum_prod(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Sum) {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
     } else {
-      reduction_op<InT, InT>(in, out, axes, 0, SumReduce());
+      reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
     }
   } else {
     if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
-      reduction_op<InT, int32_t>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
    } else {
-      reduction_op<InT, InT>(in, out, axes, 1, ProdReduce());
+      reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
     }
   }
 }
@@ -428,13 +470,14 @@ void reduce_dispatch_min_max(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
-    const std::vector<int>& axes) {
+    const std::vector<int>& axes,
+    Stream stream) {
   if (rtype == Reduce::Max) {
     auto init = Limits<InT>::min;
-    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+    reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
   } else {
     auto init = Limits<InT>::max;
-    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+    reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
   }
 }
@@ -448,24 +491,28 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       case bool_:
       case uint8:
       case int8:
-        reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_and_or<int8_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case int16:
       case uint16:
       case float16:
      case bfloat16:
-        reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_and_or<int16_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case uint32:
       case int32:
       case float32:
-        reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_and_or<int32_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case uint64:
       case int64:
       case float64:
       case complex64:
-        reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_and_or<int64_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
     }
     break;
@@ -476,34 +523,43 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       case bool_:
      case uint8:
       case int8:
-        reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<int8_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case int16:
       case uint16:
-        reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<int16_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case int32:
       case uint32:
-        reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<int32_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case int64:
       case uint64:
-        reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<int64_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case float16:
-        reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<float16_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case bfloat16:
-        reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<bfloat16_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case float32:
-        reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<float>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case float64:
-        reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<double>(
+            in, out, reduce_type_, axes_, stream());
         break;
       case complex64:
-        reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+        reduce_dispatch_sum_prod<complex64_t>(
+            in, out, reduce_type_, axes_, stream());
         break;
     }
     break;
@@ -512,46 +568,59 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
     case Reduce::Min: {
       switch (in.dtype()) {
         case bool_:
-          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
           break;
         case uint8:
-          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint16:
-          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint32:
-          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case uint64:
-          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int8:
-          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint8_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int16:
-          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<uint16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int32:
-          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<int32_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case int64:
-          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<int64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float16:
-          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<float16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
        case float32:
-          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<float>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case float64:
-          reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<double>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case bfloat16:
-          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<bfloat16_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
         case complex64:
-          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          reduce_dispatch_min_max<complex64_t>(
+              in, out, reduce_type_, axes_, stream());
           break;
       }
       break;