mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 20:18:15 +08:00
bindings for memory info (#761)
* bindings for memory info * update api * keep cache low if requested * fix default * nit in ops error
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/allocator.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
@@ -23,16 +23,6 @@ void* Buffer::raw_ptr() {
|
||||
|
||||
namespace metal {
|
||||
|
||||
static bool cache_enabled_ = true;
|
||||
|
||||
bool cache_enabled() {
|
||||
return cache_enabled_;
|
||||
}
|
||||
|
||||
void set_cache_enabled(bool enabled) {
|
||||
cache_enabled_ = enabled;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(MTL::Device* device)
|
||||
@@ -158,9 +148,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
buffer_cache_(device_),
|
||||
peak_allocated_size_(0),
|
||||
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
|
||||
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
|
||||
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
|
||||
max_pool_size_(block_limit_) {}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
};
|
||||
|
||||
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
|
||||
std::swap(limit, block_limit_);
|
||||
relaxed_ = relaxed;
|
||||
gc_limit_ = std::min(
|
||||
block_limit_,
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
|
||||
return limit;
|
||||
};
|
||||
|
||||
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Metal doesn't like empty buffers
|
||||
@@ -175,10 +179,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// Try the cache
|
||||
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
|
||||
size_t pool_size = get_cache_memory();
|
||||
if (!buf) {
|
||||
size_t mem_required = get_active_memory() + pool_size + size;
|
||||
|
||||
// If there is too much memory pressure, fail (likely causes a wait).
|
||||
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
|
||||
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
|
||||
return Buffer{nullptr};
|
||||
}
|
||||
|
||||
@@ -186,10 +192,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// If we have a lot of memory pressure, check if we can reclaim some memory
|
||||
// from the cache
|
||||
if (device_->currentAllocatedSize() + size >= gc_limit_) {
|
||||
size_t min_bytes_to_free =
|
||||
size + device_->currentAllocatedSize() - gc_limit_;
|
||||
buffer_cache_.release_cached_buffers(min_bytes_to_free);
|
||||
if (mem_required >= gc_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
}
|
||||
|
||||
// Allocate new buffer if needed
|
||||
@@ -198,15 +202,22 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
buf = device_->newBuffer(size, res_opt);
|
||||
}
|
||||
|
||||
peak_allocated_size_ =
|
||||
std::max(peak_allocated_size_, device_->currentAllocatedSize());
|
||||
// Maintain the cache below the requested limit
|
||||
if (pool_size >= max_pool_size_) {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
|
||||
}
|
||||
|
||||
active_memory_ += buf->length();
|
||||
peak_memory_ = std::max(peak_memory_, active_memory_);
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
|
||||
if (cache_enabled()) {
|
||||
active_memory_ -= buf->length();
|
||||
if (max_pool_size_ > 0) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
buf->release();
|
||||
@@ -218,6 +229,22 @@ MetalAllocator& allocator() {
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return allocator().set_cache_limit(limit);
|
||||
}
|
||||
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
return allocator().set_memory_limit(limit, relaxed);
|
||||
}
|
||||
size_t get_active_memory() {
|
||||
return allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return allocator().get_peak_memory();
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return allocator().get_cache_memory();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,6 +24,9 @@ class BufferCache {
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
void release_cached_buffers(size_t min_bytes_to_free);
|
||||
size_t pool_size() {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
@@ -54,6 +57,17 @@ class MetalAllocator : public allocator::Allocator {
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
size_t get_active_memory() {
|
||||
return active_memory_;
|
||||
};
|
||||
size_t get_peak_memory() {
|
||||
return peak_memory_;
|
||||
};
|
||||
size_t get_cache_memory() {
|
||||
return buffer_cache_.pool_size();
|
||||
};
|
||||
size_t set_cache_limit(size_t limit);
|
||||
size_t set_memory_limit(size_t limit, bool relaxed);
|
||||
|
||||
private:
|
||||
MTL::Device* device_;
|
||||
@@ -64,9 +78,12 @@ class MetalAllocator : public allocator::Allocator {
|
||||
BufferCache buffer_cache_;
|
||||
|
||||
// Allocation stats
|
||||
size_t peak_allocated_size_;
|
||||
size_t block_limit_;
|
||||
size_t gc_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
size_t max_pool_size_;
|
||||
bool relaxed_{true};
|
||||
};
|
||||
|
||||
MetalAllocator& allocator();
|
||||
|
@@ -12,8 +12,51 @@
|
||||
namespace mlx::core::metal {
|
||||
|
||||
bool is_available();
|
||||
bool cache_enabled(void);
|
||||
void set_cache_enabled(bool enabled);
|
||||
|
||||
/* Get the actively used memory in bytes.
|
||||
*
|
||||
* Note, this will not always match memory use reported by the system because
|
||||
* it does not include cached memory buffers.
|
||||
* */
|
||||
size_t get_active_memory();
|
||||
|
||||
/* Get the peak amount of used memory in bytes.
|
||||
*
|
||||
* The maximum memory used is recorded from the beginning of the program
|
||||
* execution.
|
||||
* */
|
||||
size_t get_peak_memory();
|
||||
|
||||
/* Get the cache size in bytes.
|
||||
*
|
||||
* The cache includes memory not currently used that has not been returned
|
||||
* to the system allocator.
|
||||
* */
|
||||
size_t get_cache_memory();
|
||||
|
||||
/* Set the memory limit.
|
||||
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
|
||||
* there are no more scheduled tasks an error will be raised if relaxed
|
||||
* is false or memory will be allocated (including the potential for
|
||||
* swap) if relaxed is true.
|
||||
*
|
||||
* The memory limit defaults to 1.5 times the maximum recommended working set
|
||||
* size reported by the device.
|
||||
*
|
||||
* Returns the previous memory limit.
|
||||
* */
|
||||
size_t set_memory_limit(size_t limit, bool relaxed = true);
|
||||
|
||||
/* Set the free cache limit.
|
||||
* If using more than the given limit, free memory will be reclaimed
|
||||
* from the cache on the next allocation. To disable the cache,
|
||||
* set the limit to 0.
|
||||
*
|
||||
* The cache limit defaults to the memory limit.
|
||||
*
|
||||
* Returns the previous cache limit.
|
||||
* */
|
||||
size_t set_cache_limit(size_t limit);
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
@@ -23,10 +23,21 @@ std::function<void()> make_task(
|
||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||
}
|
||||
|
||||
// No cache for CPU only
|
||||
bool cache_enabled(void) {
|
||||
return false;
|
||||
// No-ops when Metal is not available.
|
||||
size_t get_active_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_memory_limit(size_t, bool) {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void set_cache_enabled(bool) {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -1700,7 +1700,7 @@ array argpartition(
|
||||
int kth_ = kth < 0 ? kth + a.shape(axis) : kth;
|
||||
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[argpartition] Received invalid kth " << kth << "along axis "
|
||||
msg << "[argpartition] Received invalid kth " << kth << " along axis "
|
||||
<< axis << " for array with shape: " << a.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
Reference in New Issue
Block a user