mlx/mlx/backend/metal/allocator.h
Awni Hannun 0eb56d5be0
Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00

104 lines
2.2 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <map>
#include <mutex>
#include <vector>
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal {
using allocator::Buffer;
namespace {

// NOTE(review): an unnamed namespace in a header gives every translation unit
// that includes it its own distinct BufferCache type. It is kept here because
// the out-of-line member definitions are not visible from this header and
// presumably live in a matching unnamed namespace -- confirm before moving
// this class to a named (e.g. detail) namespace.

/**
 * Cache of Metal buffers.
 *
 * Holds a size-keyed multimap of BufferHolder nodes (buffer_pool_) together
 * with a doubly linked list threaded through the same nodes (head_ / tail_),
 * and tracks a running byte count in pool_size_. The member functions are
 * declared here and defined elsewhere.
 */
class BufferCache {
 public:
  BufferCache(MTL::Device* device);
  ~BufferCache();

  /** Look up a cached buffer for a request of `size`; defined elsewhere. */
  MTL::Buffer* reuse_from_cache(size_t size);

  /** Hand `buf` back to the cache; defined elsewhere. */
  void recycle_to_cache(MTL::Buffer* buf);

  /**
   * Release cached buffers; `min_bytes_to_free` is presumably a lower bound
   * on the number of bytes to release -- confirm against the definition.
   */
  void release_cached_buffers(size_t min_bytes_to_free);

  /** Current value of the cache's size counter (pool_size_). */
  size_t cache_size() const {
    return pool_size_;
  }

  void clear();

 private:
  /** List node wrapping one cached buffer. */
  struct BufferHolder {
   public:
    // Initializers are listed in member declaration order (prev, next, buf),
    // which is the order in which they actually run.
    BufferHolder(MTL::Buffer* buf_) : prev(nullptr), next(nullptr), buf(buf_) {}

    BufferHolder* prev;
    BufferHolder* next;
    MTL::Buffer* buf;
  };

  void add_at_head(BufferHolder* to_add);
  void remove_from_list(BufferHolder* to_remove);

  MTL::Device* device_;
  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_;
  BufferHolder* tail_;
  size_t pool_size_;
};

} // namespace
class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() {
return active_memory_;
};
size_t get_peak_memory() {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_cache_memory() {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
MTL::Device* device_;
MetalAllocator();
friend MetalAllocator& allocator();
// Caching allocator
BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
std::mutex mutex_;
};
// Returns a reference to a MetalAllocator. This function is a friend of
// MetalAllocator, whose constructor is private; it is defined elsewhere.
// NOTE(review): presumably returns a single shared instance -- confirm in the
// definition.
MetalAllocator& allocator();
} // namespace mlx::core::metal