redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
This commit is contained in:
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions

View File

@@ -11,6 +11,7 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
@@ -288,6 +289,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
auto contiguous = compiled_check_contiguity(inputs, shape);
auto& encoder = cpu::get_command_encoder(stream());
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
@@ -298,6 +300,7 @@ void Compiled::eval_cpu(
continue;
}
auto& x = inputs[i];
encoder.set_input_array(x);
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
@@ -356,18 +359,25 @@ void Compiled::eval_cpu(
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);
inputs, outputs, inputs_, constant_ids_, contiguous);
for (auto& x : outputs) {
args.push_back(x.data<void>());
encoder.set_output_array(x);
}
Shape out_shape;
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
out_shape = outputs[0].shape();
args.push_back((void*)out_shape.data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
encoder.dispatch(
[fun,
args = std::move(args),
strides = std::move(strides),
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
}
} // namespace mlx::core