make things simpler

This commit is contained in:
Ronan Collobert 2023-12-21 16:22:52 -08:00
parent a813bdda0a
commit 29a8b2047b
5 changed files with 55 additions and 60 deletions

View File

@ -17,10 +17,11 @@ namespace fs = std::filesystem;
namespace mlx::core::metal {
static Device metal_device_;
namespace {
// Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_memory_pool();
// TODO: find a nicer way to set this, or possibly expose it as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
@ -112,29 +113,29 @@ MTL::Library* load_library(
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
// Construct the Metal device wrapper: hold a scoped autorelease pool
// (released when `pool` goes out of scope) while loading the Metal
// device and the default "mlx" kernel library.
Device::Device() {
auto pool = new_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
}
// Release every Metal object this wrapper owns.
// NOTE(review): this span is a diff view — the kernel_map_/library_map_
// release loops appear twice (old and new hunk lines interleaved), and
// pool_ is released even though the new constructor no longer sets it;
// confirm against the actual post-commit file before relying on this text.
Device::~Device() {
// Command queues created per stream index.
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
// buffer_map_ maps index -> pair<count, MTL::CommandBuffer*>; release the buffer.
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
e.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
@ -243,6 +244,7 @@ void Device::register_library(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
@ -285,18 +287,19 @@ MTL::ComputePipelineState* Device::get_kernel(
}
Device& device(mlx::core::Device) {
return metal_device_;
static Device metal_device;
return metal_device;
}
NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
std::shared_ptr<void> new_memory_pool() {
auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release();
};
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).g_thread_autorelease_pool();
device(stream.device).new_queue(stream.index);
}
}

View File

@ -66,10 +66,7 @@ class Device {
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
NS::AutoreleasePool*& g_thread_autorelease_pool();
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;

View File

@ -48,44 +48,37 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
auto task = [retain_graph,
arr,
deps = std::move(deps),
p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchronization.
scheduler::enqueue(s, [s]() {
metal::device(s.device).g_thread_autorelease_pool()->release();
metal::device(s.device).g_thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
// Deferred GPU task: wait on upstream dependencies, evaluate the array's
// primitive on its stream, then attach a completion handler to the
// command buffer. `p` being non-null marks the output task that must
// notify the scheduler and fulfill the promise.
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
// Autorelease pool scoped to this task invocation; drained when the
// returned handle is destroyed at the end of the lambda.
auto pool = new_memory_pool();
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
// Output task: end encoding, tell the scheduler a task is in flight,
// and commit the command buffer after installing the handler.
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
// Drop graph edges unless the caller asked to retain them.
if (!retain_graph) {
arr.detach();
}
p->set_value();
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
// Intermediate task: only detach on completion; no promise to fulfill.
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
return task;
}

View File

@ -20,6 +20,7 @@ constexpr bool is_available() {
}
void new_stream(Stream stream);
std::shared_ptr<void> new_memory_pool();
std::function<void()> make_task(
array& arr,

View File

@ -35,6 +35,7 @@ struct StreamThread {
}
void thread_fn() {
auto thread_pool = metal::new_memory_pool();
metal::new_stream(stream);
while (true) {
std::function<void()> task;