defer all ObjC (via metal-cpp) interaction until after static initializers have run (#370)

* defer all ObjC (via metal-cpp) interaction until after static initializers have run

- metal-cpp relies on static initializers to cache class and selector pointers
- code in mlx was using metal-cpp to set up NSAutoreleasePools during its own static init time
- but this code was silently failing as the class and selector pointers from metal-cpp were still nil

- defer the creation of NSAutoreleasePools until after static init time
- ensure that we have coverage where autorelease pools are needed

* Update device.cpp

remove commented-out code

* Update device.cpp

remove commented out code

* Update scheduler.h

update comment

* per discussion, use the pool inside task() -- this is Metal-only and not needed for the CPU

* Update allocator.cpp

move pool to release/alloc area
This commit is contained in:
davidkoski 2024-01-04 16:12:00 -08:00 committed by GitHub
parent 75dc537e44
commit c82a8cc526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 6 deletions

View File

@ -29,6 +29,7 @@ BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
// Destructor: creates a scoped memory pool and then calls clear().
// The pool is a local declared before clear(), so it stays alive for the
// whole call and is destroyed only after clear() returns.
// NOTE(review): presumably this is an autorelease pool that drains objects
// released while the cache is cleared -- confirm against
// metal::new_scoped_memory_pool().
BufferCache::~BufferCache() {
auto thread_pool = metal::new_scoped_memory_pool();
clear();
}
@ -166,6 +167,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
auto thread_pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache
if (device_->currentAllocatedSize() + size >= gc_limit_) {

View File

@ -19,16 +19,14 @@ namespace mlx::core::metal {
namespace {
// Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
MTL::Device* device = MTL::CreateSystemDefaultDevice();
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0));
if (!device) {
throw std::runtime_error("Failed to load device");
}
@ -120,6 +118,7 @@ Device::Device() {
}
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
@ -139,6 +138,8 @@ Device::~Device() {
}
void Device::new_queue(int index) {
auto thread_pool = metal::new_scoped_memory_pool();
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);

View File

@ -35,8 +35,7 @@ struct StreamThread {
}
void thread_fn() {
auto thread_pool = metal::new_scoped_memory_pool();
metal::new_stream(stream);
bool initialized = false;
while (true) {
std::function<void()> task;
{
@ -48,6 +47,16 @@ struct StreamThread {
task = std::move(q.front());
q.pop();
}
// thread_fn may be called from a static initializer and we cannot
// call metal-cpp until all static initializers have completed. waiting
// for a task to arrive means that user code is running so metal-cpp
// can safely be called.
if (!initialized) {
initialized = true;
metal::new_stream(stream);
}
task();
}
}