more donation take 2

This commit is contained in:
Awni Hannun 2025-03-04 12:50:46 -08:00 committed by Awni Hannun
parent 35dc8580e3
commit f140792f1c
8 changed files with 106 additions and 46 deletions

View File

@ -20,21 +20,16 @@ void eval(array& arr) {
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) {
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {});
encoder.dispatch([s,
buffers = std::move(retain_buffers),
temps = std::move(encoder.temporaries())]() {
});
}
} // namespace mlx::core::cpu

View File

@ -2,11 +2,16 @@
#pragma once
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::cpu {
void eval(array& arr);
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers);
} // namespace mlx::core::cpu

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <sstream>

View File

@ -40,46 +40,32 @@ void eval(array& arr) {
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
if (d.command_buffer_needs_commit(s.index)) {
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool force_commit) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (d.command_buffer_needs_commit(s.index) || force_commit) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);

View File

@ -4,6 +4,7 @@
#include <future>
#include <memory>
#include <unordered_set>
#include "mlx/array.h"
#include "mlx/stream.h"
@ -15,7 +16,10 @@ void new_stream(Stream stream);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
void eval(array& arr);
void finalize(Stream s);
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool force_commit);
void synchronize(Stream s);
} // namespace mlx::core::metal

View File

@ -21,7 +21,10 @@ void eval(array&) {
"[metal::eval] Cannot eval on GPU without metal backend");
}
void finalize(Stream) {
void finalize(
Stream,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool) {
throw std::runtime_error(
"[metal::finalize] Cannot finalize GPU without metal backend");
}

View File

@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
auto it = cache.find(in.id());
it->second -= 1;
if (it->second != 0) {
@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}
std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>>
unretained_buffers;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();
@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
metal::finalize(e.stream());
metal::finalize(e.stream(), {}, true);
}
}
scheduler::wait_for_one();
@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) {
};
arr.set_status(array::Status::evaluated);
// TODO Maybe always want the fence coherent kernel in the same cbuf
// as the other kernels?
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers;
maybe_update_fence(arr);
for (auto& sib : arr.siblings()) {
sib.set_status(array::Status::evaluated);
maybe_update_fence(sib);
retain_buffers.insert(sib.data_shared_ptr());
}
for (auto& in : arr.inputs()) {
retain_buffers.insert(in.data_shared_ptr());
}
if (!arr.is_tracer()) {
arr.detach();
}
}
for (auto it = retain_buffers.begin(); it != retain_buffers.end();) {
if (it->use_count() > 1) {
// At this point the buffer must be in one of two states:
// 1. Held by another array
// 2. Held from a prevous async_eval
unretained_buffers.emplace(std::uintptr_t(it->get()), *it);
it = retain_buffers.erase(it);
} else {
unretained_buffers.erase(std::uintptr_t(it->get()));
++it;
}
}
if (stream.device == Device::gpu) {
metal::finalize(stream, std::move(retain_buffers), false);
} else {
cpu::finalize(stream, std::move(retain_buffers));
}
}
// Signal the event in its stream
for (auto& [_, e] : events) {
auto s = e.stream();
e.signal(s);
std::unordered_set<std::shared_ptr<array::Data>> retain;
if (s == stream) {
for (auto& [_, b] : unretained_buffers) {
auto ptr = b.lock();
if (ptr) {
retain.insert(ptr);
}
}
}
if (s.device == Device::gpu) {
metal::finalize(s);
metal::finalize(s, std::move(retain), true);
} else {
cpu::finalize(s, std::move(retain));
}
}

View File

@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase):
mx.eval(z)
mx.set_memory_limit(old_limit)
def test_donation_multiple_inputs(self):
def fun(its, x, y):
for _ in range(its):
a = x + y # y should donate
b = x + a # x should donate
x, y = a, b
return x, y
x = mx.zeros((128, 128))
y = mx.zeros((128, 128))
mx.reset_peak_memory()
a, b = fun(2, x, y)
mx.eval(a, b)
mx.synchronize()
mem2 = mx.get_peak_memory()
a, b = fun(10, x, y)
mx.eval(a, b)
mx.synchronize()
mem10 = mx.get_peak_memory()
self.assertEqual(mem2, mem10)
def test_async_with_delete(self):
a = mx.ones((5, 5))
for _ in range(100):
a = mx.abs(a)
mx.async_eval(a)
del a
mx.clear_cache()
mx.synchronize()
if __name__ == "__main__":
unittest.main()