Buffer Donation (#519)

* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
Awni Hannun
2024-01-26 16:30:33 -08:00
committed by GitHub
parent 07f35c9d8a
commit 8993382aaa
12 changed files with 199 additions and 178 deletions
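
The change at a glance: buffer donation lets an operation reuse an input's buffer as its output buffer whenever nothing else can still observe that input, skipping an allocation (and a possible malloc_or_wait stall) and lowering peak memory. The diffs below thread this through three places: the GPU copy path, command-buffer lifetime handling in make_task, and the GPU binary/unary kernels. A minimal sketch of the donation predicate the diffs lean on; the member names (array_desc_ and its data field) are assumptions about MLX internals, not verbatim source:

// Sketch only: an input can donate when this handle is the sole owner of the
// array node and of its data block, so no other consumer can read it later.
bool array::is_donatable() const {
  return array_desc_.use_count() == 1 &&   // no other node or user handle
      array_desc_->data.use_count() == 1;  // data not shared with a view
}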


@@ -12,11 +12,15 @@ namespace mlx::core {
 void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
   if (ctype == CopyType::Vector) {
-    out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-        in.data_size(),
-        in.strides(),
-        in.flags());
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
@@ -67,7 +71,8 @@ void copy_gpu_inplace(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
+  bool donate_in = in.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_in ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
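
Two idioms from this file recur below. Donation additionally requires in.itemsize() == out.itemsize(), since the output adopts the input's bytes verbatim. And once out.move_shared_buffer(in) has run, in.data_shared_ptr() is left null, which later binding code treats as the donation signal: it binds out in the input's slot so the kernel reads and writes the same allocation. A hedged usage sketch of when the copy path above can donate (the exact API spelling is illustrative, and whether donation fires depends on refcounts at eval time):

// Illustration: the temporary from ones() has a single owner, so a Vector
// copy of it into a same-itemsize output may adopt its buffer in place.
auto y = mlx::core::astype(mlx::core::ones({1024}), mlx::core::float32);
// An itemsize change (float32 -> int8) can never donate; the output gets a
// fresh allocation even when the input is otherwise donatable.
auto z = mlx::core::astype(y, mlx::core::int8);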


@@ -64,14 +64,23 @@ std::function<void()> make_task(
     auto command_buffer = increment_command_buffer(s);
     auto outputs = arr.outputs();
     arr.primitive().eval_gpu(arr.inputs(), outputs);
+    std::vector<std::shared_ptr<array::Data>> buffers;
+    for (auto& in : arr.inputs()) {
+      buffers.push_back(in.data_shared_ptr());
+    }
+    for (auto& s : arr.siblings()) {
+      buffers.push_back(s.data_shared_ptr());
+    }
+    if (!arr.is_tracer()) {
+      arr.detach();
+    }
     if (p) {
       metal::device(s.device).end_encoding(s.index);
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
-          [s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
-            if (!arr.is_tracer()) {
-              arr.detach();
-            }
+          [s, buffers = std::move(buffers), p = std::move(p)](
+              MTL::CommandBuffer* cbuf) {
             p->set_value();
             scheduler::notify_task_completion(s);
             check_error(cbuf);
@@ -79,10 +88,7 @@ std::function<void()> make_task(
       metal::device(s.device).commit_command_buffer(s.index);
     } else {
       command_buffer->addCompletedHandler(
-          [s, arr](MTL::CommandBuffer* cbuf) mutable {
-            if (!arr.is_tracer()) {
-              arr.detach();
-            }
+          [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
             check_error(cbuf);
           });
     }
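
This make_task change settles the lifetime question donation raises: detaching used to happen inside the completion handler, which kept the whole arr (inputs included) captured until the GPU finished. Now the task snapshots the inputs' and siblings' data shared pointers into buffers, detaches immediately, and moves buffers into the handler, so graph edges drop right away while the raw allocations survive until the command buffer completes. A runnable analog of the retain-until-complete pattern, with std::thread standing in for the command buffer (names here are illustrative, not MLX API):

#include <memory>
#include <thread>
#include <vector>

using Buffer = std::vector<float>;

// Capture the shared_ptrs by value: the buffers outlive the caller's graph
// teardown and are released only when the async work finishes.
void run_async(std::vector<std::shared_ptr<Buffer>> buffers) {
  std::thread([buffers = std::move(buffers)] {
    // ... device-like work reads *buffers[i] here ...
  }).detach();  // buffers' destructor runs when the lambda finishes
}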


@@ -27,8 +27,8 @@ void binary_op(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, outputs[0], bopt);
-  set_binary_op_output_data(a, b, outputs[1], bopt);
+  set_binary_op_output_data(a, b, outputs[0], bopt, true);
+  set_binary_op_output_data(a, b, outputs[1], bopt, true);
   auto& out = outputs[0];
   if (out.size() == 0) {
@@ -69,8 +69,14 @@ void binary_op(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
+  // - If a is donated it goes to the first output
+  // - If b is donated it goes to the first output if a was not donated
+  //   otherwise it goes to the second output
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
+  set_array_buffer(
+      compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
   set_array_buffer(compute_encoder, outputs[0], 2);
   set_array_buffer(compute_encoder, outputs[1], 3);
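
The routing comment compresses four cases; spelled out (derived directly from the two ternaries above, illustration only):

// donate_a  donate_b   slot 0 binds   slot 1 binds
// --------  --------   ------------   ------------
// false     false      a              b
// true      false      outputs[0]     b
// false     true       a              outputs[0]
// true      true       outputs[0]     outputs[1]
//
// Each donated input aliases the output that adopted its buffer, so the
// kernel's read slot and its write slot point at the same allocation.
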
@@ -122,7 +128,7 @@ void binary_op(
   auto& a = inputs[0];
   auto& b = inputs[1];
   auto bopt = get_binary_op_type(a, b);
-  set_binary_op_output_data(a, b, out, bopt);
+  set_binary_op_output_data(a, b, out, bopt, true);
   if (out.size() == 0) {
     return;
   }
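
Both binary paths now pass a trailing true to set_binary_op_output_data, opting the outputs in to donation. A hedged sketch of what that flag plausibly does inside the helper; the real implementation lives in the shared backend and handles more layout cases, so everything beyond the names visible in this diff is an assumption:

// Sketch: with donate set, try to move a donatable input's buffer into the
// output before falling back to a fresh allocation.
void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt,
    bool donate) {
  if (donate && bopt == VectorVector && a.is_donatable() &&
      a.itemsize() == out.itemsize()) {
    out.move_shared_buffer(a);
  } else if (
      donate && bopt == VectorVector && b.is_donatable() &&
      b.itemsize() == out.itemsize()) {
    out.move_shared_buffer(b);
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));  // no donation
  }
}
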
@@ -161,8 +167,10 @@ void binary_op(
   auto kernel = d.get_kernel(kname.str());
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, a, 0);
-  set_array_buffer(compute_encoder, b, 1);
+  bool donate_a = a.data_shared_ptr() == nullptr;
+  bool donate_b = b.data_shared_ptr() == nullptr;
+  set_array_buffer(compute_encoder, donate_a ? out : a, 0);
+  set_array_buffer(compute_encoder, donate_b ? out : b, 1);
   set_array_buffer(compute_encoder, out, 2);
   if (bopt == General) {
@@ -212,11 +220,15 @@ void unary_op(
   auto& in = inputs[0];
   bool contig = in.flags().contiguous;
   if (contig) {
-    out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
-        in.data_size(),
-        in.strides(),
-        in.flags());
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.move_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
@@ -240,7 +252,8 @@ void unary_op(
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
-  set_array_buffer(compute_encoder, in, 0);
+  set_array_buffer(
+      compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
   set_array_buffer(compute_encoder, out, 1);
   if (!contig) {
     compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
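
Net effect of the whole commit: in a chain of single-use elementwise ops, each intermediate can hand its allocation to the next, so peak memory stays near one live buffer instead of one per node. An illustrative usage (API spelling sketched; whether donation actually fires depends on refcounts and layouts at eval time):

// x stays referenced below, so abs(x) must allocate a fresh buffer; but the
// "+ 1.0f" and exp steps each consume a single-owner temporary and may
// adopt its buffer instead of allocating.
auto x = mlx::core::random::normal({1 << 20});
auto y = mlx::core::exp(mlx::core::abs(x) + 1.0f);
mlx::core::eval(y);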