make things simpler

This commit is contained in:
Ronan Collobert 2023-12-21 16:22:52 -08:00
parent a813bdda0a
commit 29a8b2047b
5 changed files with 55 additions and 60 deletions

View File

@ -17,10 +17,11 @@ namespace fs = std::filesystem;
namespace mlx::core::metal { namespace mlx::core::metal {
static Device metal_device_;
namespace { namespace {
// Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_memory_pool();
// TODO nicer way to set this or possibly expose as an environment variable // TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12; static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
@ -112,29 +113,29 @@ MTL::Library* load_library(
} // namespace } // namespace
Device::Device() Device::Device() {
: pool_(NS::AutoreleasePool::alloc()->init()), auto pool = new_memory_pool();
device_(load_device()), device_ = load_device();
library_map_({{"mlx", load_library(device_)}}) {} library_map_ = {{"mlx", load_library(device_)}};
}
Device::~Device() { Device::~Device() {
for (auto& q : queue_map_) { for (auto& q : queue_map_) {
q.second->release(); q.second->release();
} }
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
for (auto& b : buffer_map_) { for (auto& b : buffer_map_) {
b.second.second->release(); b.second.second->release();
} }
for (auto& e : encoder_map_) { for (auto& e : encoder_map_) {
e.second->release(); e.second->release();
} }
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release(); device_->release();
pool_->release();
} }
void Device::new_queue(int index) { void Device::new_queue(int index) {
@ -243,6 +244,7 @@ void Device::register_library(
MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel(
const std::string& name, const std::string& name,
const std::string& lib_name /* = "mlx" */) { const std::string& lib_name /* = "mlx" */) {
auto pool = new_memory_pool();
// Look for cached kernel // Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second; return it->second;
@ -285,18 +287,19 @@ MTL::ComputePipelineState* Device::get_kernel(
} }
Device& device(mlx::core::Device) { Device& device(mlx::core::Device) {
return metal_device_; static Device metal_device;
return metal_device;
} }
NS::AutoreleasePool*& Device::g_thread_autorelease_pool() { std::shared_ptr<void> new_memory_pool() {
static thread_local NS::AutoreleasePool* p = auto dtor = [](void* ptr) {
NS::AutoreleasePool::alloc()->init(); static_cast<NS::AutoreleasePool*>(ptr)->release();
return p; };
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
} }
void new_stream(Stream stream) { void new_stream(Stream stream) {
if (stream.device == mlx::core::Device::gpu) { if (stream.device == mlx::core::Device::gpu) {
device(stream.device).g_thread_autorelease_pool();
device(stream.device).new_queue(stream.index); device(stream.device).new_queue(stream.index);
} }
} }

View File

@ -66,10 +66,7 @@ class Device {
MTL::ArgumentEncoder* argument_encoder( MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
NS::AutoreleasePool*& g_thread_autorelease_pool();
private: private:
NS::AutoreleasePool* pool_;
MTL::Device* device_; MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_; std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;

View File

@ -48,44 +48,37 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p, std::shared_ptr<std::promise<void>> p,
bool retain_graph) { bool retain_graph) {
auto task = [retain_graph, auto task =
arr, [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
deps = std::move(deps), auto pool = new_memory_pool();
p = std::move(p)]() mutable { for (auto& d : deps) {
for (auto& d : deps) { d.wait();
d.wait(); }
} auto s = arr.primitive().stream();
auto s = arr.primitive().stream(); auto command_buffer = increment_command_buffer(s);
auto command_buffer = increment_command_buffer(s); arr.primitive().eval_gpu(arr.inputs(), arr);
arr.primitive().eval_gpu(arr.inputs(), arr); if (p) {
if (p) { metal::device(s.device).end_encoding(s.index);
metal::device(s.device).end_encoding(s.index); scheduler::notify_new_task(s);
scheduler::notify_new_task(s); command_buffer->addCompletedHandler(
command_buffer->addCompletedHandler( [retain_graph, s, arr, p = std::move(p)](
[retain_graph, s, arr, p = std::move(p)]( MTL::CommandBuffer*) mutable {
MTL::CommandBuffer*) mutable { if (!retain_graph) {
if (!retain_graph) { arr.detach();
arr.detach(); }
} p->set_value();
p->set_value(); scheduler::notify_task_completion(s);
// Signal this thread to clear the pool on a synchronization. });
scheduler::enqueue(s, [s]() { metal::device(s.device).commit_command_buffer(s.index);
metal::device(s.device).g_thread_autorelease_pool()->release(); } else {
metal::device(s.device).g_thread_autorelease_pool() = command_buffer->addCompletedHandler(
NS::AutoreleasePool::alloc()->init(); [retain_graph, s, arr](MTL::CommandBuffer*) mutable {
}); if (!retain_graph) {
scheduler::notify_task_completion(s); arr.detach();
}); }
metal::device(s.device).commit_command_buffer(s.index); });
} else { }
command_buffer->addCompletedHandler( };
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
return task; return task;
} }

View File

@ -20,6 +20,7 @@ constexpr bool is_available() {
} }
void new_stream(Stream stream); void new_stream(Stream stream);
std::shared_ptr<void> new_memory_pool();
std::function<void()> make_task( std::function<void()> make_task(
array& arr, array& arr,

View File

@ -35,6 +35,7 @@ struct StreamThread {
} }
void thread_fn() { void thread_fn() {
auto thread_pool = metal::new_memory_pool();
metal::new_stream(stream); metal::new_stream(stream);
while (true) { while (true) {
std::function<void()> task; std::function<void()> task;