diff --git a/mlx/backend/metal/resident.cpp b/mlx/backend/metal/resident.cpp index 403857c6d..ec2560aa4 100644 --- a/mlx/backend/metal/resident.cpp +++ b/mlx/backend/metal/resident.cpp @@ -5,11 +5,10 @@ namespace mlx::core::metal { -// TODO maybe worth including tvos / visionos -#define supported __builtin_available(macOS 15, iOS 18, *) - ResidencySet::ResidencySet(MTL::Device* d) { - if (supported) { + if (!d->supportsFamily(MTL::GPUFamilyMetal3)) { + return; + } else if (__builtin_available(macOS 15, iOS 18, *)) { auto pool = new_scoped_memory_pool(); auto desc = MTL::ResidencySetDescriptor::alloc()->init(); NS::Error* error; @@ -27,68 +26,72 @@ ResidencySet::ResidencySet(MTL::Device* d) { } void ResidencySet::insert(MTL::Allocation* buf) { - if (supported) { - if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) { - wired_set_->addAllocation(buf); - wired_set_->commit(); - wired_set_->requestResidency(); - } else { - unwired_set_.insert(buf); - } + if (!wired_set_) { + return; + } + if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) { + wired_set_->addAllocation(buf); + wired_set_->commit(); + wired_set_->requestResidency(); + } else { + unwired_set_.insert(buf); } } void ResidencySet::erase(MTL::Allocation* buf) { - if (supported) { - if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) { - unwired_set_.erase(it); - } else { - wired_set_->removeAllocation(buf); - wired_set_->commit(); - } + if (!wired_set_) { + return; + } + if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) { + unwired_set_.erase(it); + } else { + wired_set_->removeAllocation(buf); + wired_set_->commit(); } } void ResidencySet::resize(size_t size) { - if (supported) { - if (capacity_ == size) { - return; - } - capacity_ = size; + if (!wired_set_) { + return; + } - size_t current_size = wired_set_->allocatedSize(); + if (capacity_ == size) { + return; + } + capacity_ = size; - if (current_size < size) { - // Add unwired allocations to the set - for (auto it = unwired_set_.begin(); it != unwired_set_.end();) { - auto buf_size = (*it)->allocatedSize(); - if (current_size + buf_size > size) { - it++; - } else { - current_size += buf_size; - wired_set_->addAllocation(*it); - unwired_set_.erase(it++); - } + size_t current_size = wired_set_->allocatedSize(); + + if (current_size < size) { + // Add unwired allocations to the set + for (auto it = unwired_set_.begin(); it != unwired_set_.end();) { + auto buf_size = (*it)->allocatedSize(); + if (current_size + buf_size > size) { + it++; + } else { + current_size += buf_size; + wired_set_->addAllocation(*it); + unwired_set_.erase(it++); } - wired_set_->commit(); - wired_set_->requestResidency(); - } else if (current_size > size) { - // Remove wired allocations until under capacity - auto allocations = wired_set_->allAllocations(); - auto num_allocations = wired_set_->allocationCount(); - for (int i = 0; i < num_allocations && current_size > size; ++i) { - auto buf = static_cast(allocations->object(i)); - wired_set_->removeAllocation(buf); - current_size -= buf->allocatedSize(); - unwired_set_.insert(buf); - } - wired_set_->commit(); } + wired_set_->commit(); + wired_set_->requestResidency(); + } else if (current_size > size) { + // Remove wired allocations until under capacity + auto allocations = wired_set_->allAllocations(); + auto num_allocations = wired_set_->allocationCount(); + for (int i = 0; i < num_allocations && current_size > size; ++i) { + auto buf = static_cast(allocations->object(i)); + wired_set_->removeAllocation(buf); + current_size -= buf->allocatedSize(); + unwired_set_.insert(buf); + } + wired_set_->commit(); } } ResidencySet::~ResidencySet() { - if (supported) { + if (wired_set_) { wired_set_->release(); } }