more donation take 2
This commit is contained in: parent 35dc8580e3, commit f140792f1c
@@ -20,21 +20,16 @@ void eval(array& arr) {
}
arr.primitive().eval_cpu(arr.inputs(), outputs);
}
}

std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) {
auto& encoder = cpu::get_command_encoder(s);
encoder.dispatch([buffers = std::move(buffers),
temps = std::move(encoder.temporaries())]() {});
encoder.dispatch([s,
buffers = std::move(retain_buffers),
temps = std::move(encoder.temporaries())]() {
});
}

} // namespace mlx::core::cpu
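For context, here is a minimal sketch of how the new CPU finalize entry point above could be driven by a caller. Only the cpu::eval / cpu::finalize signatures are taken from this commit; the header path, the helper collect_retained, and the wrapper function are illustrative assumptions, not the repository's actual call site.

// Sketch only: assumes the cpu::eval / cpu::finalize signatures shown in this
// commit; the include path and the helper names below are hypothetical.
#include <memory>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/backend/cpu/eval.h" // assumed location of cpu::eval / cpu::finalize
#include "mlx/stream.h"

using mlx::core::array;

// Collect the buffers that must stay alive until the dispatched CPU work runs:
// the inputs and siblings of the array that was just evaluated.
std::unordered_set<std::shared_ptr<array::Data>> collect_retained(array& arr) {
  std::unordered_set<std::shared_ptr<array::Data>> retained;
  for (auto& in : arr.inputs()) {
    retained.insert(in.data_shared_ptr());
  }
  for (auto& sib : arr.siblings()) {
    retained.insert(sib.data_shared_ptr());
  }
  return retained;
}

void eval_and_finalize(array& arr, mlx::core::Stream s) {
  mlx::core::cpu::eval(arr);
  // finalize takes the set by value and moves it into the task it dispatches,
  // so the retained buffers outlive any in-flight work on the stream.
  mlx::core::cpu::finalize(s, collect_retained(arr));
}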
@@ -2,11 +2,16 @@

#pragma once

#include <unordered_set>

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::cpu {

void eval(array& arr);
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers);

} // namespace mlx::core::cpu
@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#include <cstdlib>
#include <sstream>

@@ -40,46 +40,32 @@ void eval(array& arr) {
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
}

if (d.command_buffer_needs_commit(s.index)) {
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool force_commit) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (d.command_buffer_needs_commit(s.index) || force_commit) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}

void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}

void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
@@ -4,6 +4,7 @@

#include <future>
#include <memory>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/stream.h"
@@ -15,7 +16,10 @@ void new_stream(Stream stream);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

void eval(array& arr);
void finalize(Stream s);
void finalize(
Stream s,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool force_commit);
void synchronize(Stream s);

} // namespace mlx::core::metal
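As a reading aid, a hedged sketch of how the extended finalize declarations above are meant to be dispatched, mirroring the call sites that appear later in this commit (metal::finalize for GPU streams, cpu::finalize for CPU streams). The wrapper function and the include paths are assumptions for illustration only.

// Illustrative wrapper, not part of the commit: shows how a caller might pick
// the backend finalize overload declared in this diff.
#include <memory>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/backend/cpu/eval.h" // assumed location of cpu::finalize
#include "mlx/backend/metal/metal.h" // assumed location of metal::finalize
#include "mlx/stream.h"

using mlx::core::array;

void finish_stream(
    mlx::core::Stream s,
    std::unordered_set<std::shared_ptr<array::Data>> retained,
    bool force_commit) {
  if (s.device == mlx::core::Device::gpu) {
    // metal::finalize attaches the retained buffers to the command buffer's
    // completion handler; force_commit submits the buffer unconditionally
    // (e.g. right before signaling an event) instead of only when it is full.
    mlx::core::metal::finalize(s, std::move(retained), force_commit);
  } else {
    // cpu::finalize captures the retained buffers in a task dispatched on the
    // stream's encoder, so they live until all queued CPU work has run.
    mlx::core::cpu::finalize(s, std::move(retained));
  }
}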
@@ -21,7 +21,10 @@ void eval(array&) {
"[metal::eval] Cannot eval on GPU without metal backend");
}

void finalize(Stream) {
void finalize(
Stream,
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
bool) {
throw std::runtime_error(
"[metal::finalize] Cannot finalize GPU without metal backend");
}
@@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
}

auto it = cache.find(in.id());

it->second -= 1;

if (it->second != 0) {
@@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}

std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>>
unretained_buffers;
while (!tape.empty()) {
auto arr = std::move(tape.back());
tape.pop_back();
@@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
// Commit any open streams
for (auto& [_, e] : events) {
if (e.stream().device == Device::gpu) {
metal::finalize(e.stream());
metal::finalize(e.stream(), {}, true);
}
}
scheduler::wait_for_one();
@@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) {
};

arr.set_status(array::Status::evaluated);
// TODO Maybe always want the fence coherent kernel in the same cbuf
// as the other kernels?

std::unordered_set<std::shared_ptr<array::Data>> retain_buffers;

maybe_update_fence(arr);
for (auto& sib : arr.siblings()) {
sib.set_status(array::Status::evaluated);
maybe_update_fence(sib);
retain_buffers.insert(sib.data_shared_ptr());
}

for (auto& in : arr.inputs()) {
retain_buffers.insert(in.data_shared_ptr());
}

if (!arr.is_tracer()) {
arr.detach();
}
}

for (auto it = retain_buffers.begin(); it != retain_buffers.end();) {
if (it->use_count() > 1) {
// At this point the buffer must be in one of two states:
// 1. Held by another array
// 2. Held from a previous async_eval
unretained_buffers.emplace(std::uintptr_t(it->get()), *it);
it = retain_buffers.erase(it);
} else {
unretained_buffers.erase(std::uintptr_t(it->get()));
++it;
}
}
if (stream.device == Device::gpu) {
metal::finalize(stream, std::move(retain_buffers), false);
} else {
cpu::finalize(stream, std::move(retain_buffers));
}
}
// Signal the event in its stream
for (auto& [_, e] : events) {
auto s = e.stream();
e.signal(s);
std::unordered_set<std::shared_ptr<array::Data>> retain;
if (s == stream) {
for (auto& [_, b] : unretained_buffers) {
auto ptr = b.lock();
if (ptr) {
retain.insert(ptr);
}
}
}
if (s.device == Device::gpu) {
metal::finalize(s);
metal::finalize(s, std::move(retain), true);
} else {
cpu::finalize(s, std::move(retain));
}
}

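The loop above that splits retain_buffers into a retained set and unretained_buffers is the core of the donation bookkeeping: a buffer some other owner still holds is tracked only weakly, while a buffer this eval uniquely owns keeps its strong reference so the backend finalize can pin it. A self-contained sketch of that pattern, with a plain std::shared_ptr<int> standing in for std::shared_ptr<array::Data>, follows; the names here are illustrative, not the repository's code.

// Standalone illustration of the retain / unretain split used above.
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>

using Buffer = std::shared_ptr<int>; // stand-in for std::shared_ptr<array::Data>

void split_retained(
    std::unordered_set<Buffer>& retain_buffers,
    std::unordered_map<std::uintptr_t, std::weak_ptr<int>>& unretained_buffers) {
  for (auto it = retain_buffers.begin(); it != retain_buffers.end();) {
    if (it->use_count() > 1) {
      // Someone else (another array, or a previous async_eval) still owns the
      // buffer, so a weak reference is enough; it can be re-pinned later if
      // the event-signaling stream still needs it.
      unretained_buffers.emplace(std::uintptr_t(it->get()), *it);
      it = retain_buffers.erase(it);
    } else {
      // This eval is the only owner; keep the strong reference so finalize
      // can hold the buffer until the stream's work completes.
      unretained_buffers.erase(std::uintptr_t(it->get()));
      ++it;
    }
  }
}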
@@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase):
        mx.eval(z)
        mx.set_memory_limit(old_limit)

    def test_donation_multiple_inputs(self):
        def fun(its, x, y):
            for _ in range(its):
                a = x + y  # y should donate
                b = x + a  # x should donate
                x, y = a, b
            return x, y

        x = mx.zeros((128, 128))
        y = mx.zeros((128, 128))
        mx.reset_peak_memory()
        a, b = fun(2, x, y)
        mx.eval(a, b)
        mx.synchronize()
        mem2 = mx.get_peak_memory()
        a, b = fun(10, x, y)
        mx.eval(a, b)
        mx.synchronize()
        mem10 = mx.get_peak_memory()
        self.assertEqual(mem2, mem10)

    def test_async_with_delete(self):
        a = mx.ones((5, 5))
        for _ in range(100):
            a = mx.abs(a)
        mx.async_eval(a)
        del a
        mx.clear_cache()
        mx.synchronize()


if __name__ == "__main__":
    unittest.main()