mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
more donation take 2
This commit is contained in:
parent
35dc8580e3
commit
f140792f1c
@ -20,21 +20,16 @@ void eval(array& arr) {
|
|||||||
}
|
}
|
||||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
void finalize(
|
||||||
for (auto& in : arr.inputs()) {
|
Stream s,
|
||||||
buffers.insert(in.data_shared_ptr());
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers) {
|
||||||
}
|
|
||||||
for (auto& s : arr.siblings()) {
|
|
||||||
buffers.insert(s.data_shared_ptr());
|
|
||||||
}
|
|
||||||
// Remove the output if it was donated to by an input
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
|
||||||
auto& encoder = cpu::get_command_encoder(s);
|
auto& encoder = cpu::get_command_encoder(s);
|
||||||
encoder.dispatch([buffers = std::move(buffers),
|
encoder.dispatch([s,
|
||||||
temps = std::move(encoder.temporaries())]() {});
|
buffers = std::move(retain_buffers),
|
||||||
|
temps = std::move(encoder.temporaries())]() {
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
} // namespace mlx::core::cpu
|
||||||
|
@ -2,11 +2,16 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
namespace mlx::core::cpu {
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
void eval(array& arr);
|
void eval(array& arr);
|
||||||
|
void finalize(
|
||||||
|
Stream s,
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers);
|
||||||
|
|
||||||
} // namespace mlx::core::cpu
|
} // namespace mlx::core::cpu
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
@ -40,46 +40,32 @@ void eval(array& arr) {
|
|||||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||||
}
|
}
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
}
|
||||||
for (auto& in : arr.inputs()) {
|
|
||||||
buffers.insert(in.data_shared_ptr());
|
|
||||||
}
|
|
||||||
for (auto& s : arr.siblings()) {
|
|
||||||
buffers.insert(s.data_shared_ptr());
|
|
||||||
}
|
|
||||||
// Remove the output if it was donated to by an input
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (d.command_buffer_needs_commit(s.index)) {
|
void finalize(
|
||||||
|
Stream s,
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
|
||||||
|
bool force_commit) {
|
||||||
|
auto pool = new_scoped_memory_pool();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto command_buffer = d.get_command_buffer(s.index);
|
||||||
|
if (d.command_buffer_needs_commit(s.index) || force_commit) {
|
||||||
d.end_encoding(s.index);
|
d.end_encoding(s.index);
|
||||||
scheduler::notify_new_task(s);
|
scheduler::notify_new_task(s);
|
||||||
command_buffer->addCompletedHandler(
|
command_buffer->addCompletedHandler(
|
||||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
|
||||||
scheduler::notify_task_completion(s);
|
scheduler::notify_task_completion(s);
|
||||||
check_error(cbuf);
|
check_error(cbuf);
|
||||||
});
|
});
|
||||||
d.commit_command_buffer(s.index);
|
d.commit_command_buffer(s.index);
|
||||||
d.get_command_buffer(s.index);
|
|
||||||
} else {
|
} else {
|
||||||
command_buffer->addCompletedHandler(
|
command_buffer->addCompletedHandler(
|
||||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
[s, buffers = std::move(retain_buffers)](MTL::CommandBuffer* cbuf) {
|
||||||
check_error(cbuf);
|
check_error(cbuf);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void finalize(Stream s) {
|
|
||||||
auto pool = new_scoped_memory_pool();
|
|
||||||
auto& d = metal::device(s.device);
|
|
||||||
auto cb = d.get_command_buffer(s.index);
|
|
||||||
d.end_encoding(s.index);
|
|
||||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
|
|
||||||
d.commit_command_buffer(s.index);
|
|
||||||
d.get_command_buffer(s.index);
|
|
||||||
}
|
|
||||||
|
|
||||||
void synchronize(Stream s) {
|
void synchronize(Stream s) {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
@ -15,7 +16,10 @@ void new_stream(Stream stream);
|
|||||||
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
||||||
|
|
||||||
void eval(array& arr);
|
void eval(array& arr);
|
||||||
void finalize(Stream s);
|
void finalize(
|
||||||
|
Stream s,
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
|
||||||
|
bool force_commit);
|
||||||
void synchronize(Stream s);
|
void synchronize(Stream s);
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -21,7 +21,10 @@ void eval(array&) {
|
|||||||
"[metal::eval] Cannot eval on GPU without metal backend");
|
"[metal::eval] Cannot eval on GPU without metal backend");
|
||||||
}
|
}
|
||||||
|
|
||||||
void finalize(Stream) {
|
void finalize(
|
||||||
|
Stream,
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers,
|
||||||
|
bool) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::finalize] Cannot finalize GPU without metal backend");
|
"[metal::finalize] Cannot finalize GPU without metal backend");
|
||||||
}
|
}
|
||||||
|
@ -160,6 +160,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto it = cache.find(in.id());
|
auto it = cache.find(in.id());
|
||||||
|
|
||||||
it->second -= 1;
|
it->second -= 1;
|
||||||
|
|
||||||
if (it->second != 0) {
|
if (it->second != 0) {
|
||||||
@ -180,6 +181,8 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_map<std::uintptr_t, std::weak_ptr<array::Data>>
|
||||||
|
unretained_buffers;
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
auto arr = std::move(tape.back());
|
auto arr = std::move(tape.back());
|
||||||
tape.pop_back();
|
tape.pop_back();
|
||||||
@ -225,7 +228,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
// Commit any open streams
|
// Commit any open streams
|
||||||
for (auto& [_, e] : events) {
|
for (auto& [_, e] : events) {
|
||||||
if (e.stream().device == Device::gpu) {
|
if (e.stream().device == Device::gpu) {
|
||||||
metal::finalize(e.stream());
|
metal::finalize(e.stream(), {}, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scheduler::wait_for_one();
|
scheduler::wait_for_one();
|
||||||
@ -246,24 +249,59 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
arr.set_status(array::Status::evaluated);
|
arr.set_status(array::Status::evaluated);
|
||||||
// TODO Maybe always want the fence coherent kernel in the same cbuf
|
|
||||||
// as the other kernels?
|
std::unordered_set<std::shared_ptr<array::Data>> retain_buffers;
|
||||||
|
|
||||||
maybe_update_fence(arr);
|
maybe_update_fence(arr);
|
||||||
for (auto& sib : arr.siblings()) {
|
for (auto& sib : arr.siblings()) {
|
||||||
sib.set_status(array::Status::evaluated);
|
sib.set_status(array::Status::evaluated);
|
||||||
maybe_update_fence(sib);
|
maybe_update_fence(sib);
|
||||||
|
retain_buffers.insert(sib.data_shared_ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto& in : arr.inputs()) {
|
||||||
|
retain_buffers.insert(in.data_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
if (!arr.is_tracer()) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
for (auto it = retain_buffers.begin(); it != retain_buffers.end();) {
|
||||||
|
if (it->use_count() > 1) {
|
||||||
|
// At this point the buffer must be in one of two states:
|
||||||
|
// 1. Held by another array
|
||||||
|
// 2. Held from a previous async_eval
|
||||||
|
unretained_buffers.emplace(std::uintptr_t(it->get()), *it);
|
||||||
|
it = retain_buffers.erase(it);
|
||||||
|
} else {
|
||||||
|
unretained_buffers.erase(std::uintptr_t(it->get()));
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (stream.device == Device::gpu) {
|
||||||
|
metal::finalize(stream, std::move(retain_buffers), false);
|
||||||
|
} else {
|
||||||
|
cpu::finalize(stream, std::move(retain_buffers));
|
||||||
|
}
|
||||||
|
}
|
||||||
// Signal the event in its stream
|
// Signal the event in its stream
|
||||||
for (auto& [_, e] : events) {
|
for (auto& [_, e] : events) {
|
||||||
auto s = e.stream();
|
auto s = e.stream();
|
||||||
e.signal(s);
|
e.signal(s);
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> retain;
|
||||||
|
if (s == stream) {
|
||||||
|
for (auto& [_, b] : unretained_buffers) {
|
||||||
|
auto ptr = b.lock();
|
||||||
|
if (ptr) {
|
||||||
|
retain.insert(ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if (s.device == Device::gpu) {
|
if (s.device == Device::gpu) {
|
||||||
metal::finalize(s);
|
metal::finalize(s, std::move(retain), true);
|
||||||
|
} else {
|
||||||
|
cpu::finalize(s, std::move(retain));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
mx.set_memory_limit(old_limit)
|
mx.set_memory_limit(old_limit)
|
||||||
|
|
||||||
|
def test_donation_multiple_inputs(self):
|
||||||
|
def fun(its, x, y):
|
||||||
|
for _ in range(its):
|
||||||
|
a = x + y # y should donate
|
||||||
|
b = x + a # x should donate
|
||||||
|
x, y = a, b
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
x = mx.zeros((128, 128))
|
||||||
|
y = mx.zeros((128, 128))
|
||||||
|
mx.reset_peak_memory()
|
||||||
|
a, b = fun(2, x, y)
|
||||||
|
mx.eval(a, b)
|
||||||
|
mx.synchronize()
|
||||||
|
mem2 = mx.get_peak_memory()
|
||||||
|
a, b = fun(10, x, y)
|
||||||
|
mx.eval(a, b)
|
||||||
|
mx.synchronize()
|
||||||
|
mem10 = mx.get_peak_memory()
|
||||||
|
self.assertEqual(mem2, mem10)
|
||||||
|
|
||||||
|
def test_async_with_delete(self):
|
||||||
|
a = mx.ones((5, 5))
|
||||||
|
for _ in range(100):
|
||||||
|
a = mx.abs(a)
|
||||||
|
mx.async_eval(a)
|
||||||
|
del a
|
||||||
|
mx.clear_cache()
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user