diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 528e1db76..aad031cfa 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -140,10 +140,15 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), - buffer_cache_(device_), - block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()), - gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()), - max_pool_size_(block_limit_) {} + buffer_cache_(device_) { + auto memsize = std::get(device_info()["memory_size"]); + block_limit_ = + std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize); + gc_limit_ = std::min( + static_cast(0.95 * device_->recommendedMaxWorkingSetSize()), + block_limit_); + max_pool_size_ = block_limit_; +} size_t MetalAllocator::set_cache_limit(size_t limit) { std::swap(limit, max_pool_size_); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 059e12b01..155fdf356 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -5,6 +5,8 @@ #include #include +#include + #define NS_PRIVATE_IMPLEMENTATION #define CA_PRIVATE_IMPLEMENTATION #define MTL_PRIVATE_IMPLEMENTATION @@ -560,11 +562,19 @@ std::unordered_map> device_info() { auto raw_device = device(default_device()).mtl_device(); auto arch = std::string(raw_device->architecture()->name()->utf8String()); + + int mib[] = {CTL_HW, HW_MEMSIZE}; + size_t memsize = 0; + size_t length = sizeof(memsize); + + sysctl(mib, 2, &memsize, &length, NULL, 0); + return { {"architecture", arch}, {"max_buffer_length", raw_device->maxBufferLength()}, {"max_recommended_working_set_size", - raw_device->recommendedMaxWorkingSetSize()}}; + raw_device->recommendedMaxWorkingSetSize()}, + {"memory_size", memsize}}; } } // namespace mlx::core::metal diff --git a/python/src/metal.cpp b/python/src/metal.cpp index b29255e2a..fef2cc69a 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -129,6 +129,7 @@ void init_metal(nb::module_& m) { * ``architecture`` * ``max_buffer_size`` * ``max_recommended_working_set_size`` + * ``memory_size`` Returns: dict: A dictionary with string keys and string or integer values.