Add stats and limit to common allocator and enable tests (#1988)

* add stats to common allocator and enable tests

* linux memory and default

* fix
This commit is contained in:
Awni Hannun
2025-03-21 12:28:36 -07:00
committed by GitHub
parent d343782c8b
commit 2a980a76ce
13 changed files with 151 additions and 68 deletions

View File

@@ -4,7 +4,6 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
@@ -22,23 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
// Allocate `size` usable bytes plus a leading size_t header that records
// `size`. The returned Buffer wraps the start of the header; it wraps nullptr
// when std::malloc fails.
Buffer CommonAllocator::malloc(size_t size) {
  auto* header = static_cast<size_t*>(std::malloc(sizeof(size_t) + size));
  if (header) {
    *header = size;
  }
  return Buffer{static_cast<void*>(header)};
}
// Hand the block wrapped by `buffer` (header included) back to the C heap.
void CommonAllocator::free(Buffer buffer) {
  void* block = buffer.ptr();
  std::free(block);
}
// Report the byte count stored in the size_t header at the start of the
// block, or 0 for a buffer that wraps nullptr.
size_t CommonAllocator::size(Buffer buffer) const {
  const auto* header = static_cast<size_t*>(buffer.ptr());
  return header != nullptr ? *header : 0;
}
} // namespace mlx::core::allocator

View File

@@ -49,16 +49,4 @@ class Allocator {
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator

View File

@@ -1,38 +1,123 @@
// Copyright © 2023 Apple Inc.
#include <mutex>
#include "mlx/allocator.h"
#ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#else
namespace {

// Fallback used when the target is neither Apple nor Linux: the physical
// memory size is not queried and 0 is reported instead. CommonAllocator
// treats 0 as "unknown" and substitutes its built-in default limit.
// Kept in an anonymous namespace, like the platform-specific variants, so the
// symbol stays local to this translation unit.
size_t get_memory_size() {
  return 0;
}

} // namespace
#endif
namespace mlx::core {
namespace allocator {
Allocator& allocator() {
/**
 * A general CPU allocator.
 *
 * Besides allocating, it keeps running statistics (bytes currently active and
 * the peak of that value) and stores a memory limit. All statistic fields are
 * guarded by `mutex_`, for readers as well as writers.
 */
class CommonAllocator : public Allocator {
 public:
  virtual Buffer malloc(size_t size) override;
  virtual void free(Buffer buffer) override;
  virtual size_t size(Buffer buffer) const override;

  /** Bytes currently accounted as allocated and not yet freed. */
  size_t get_active_memory() const {
    std::unique_lock lk(mutex_);
    return active_memory_;
  }

  /** Largest value the active byte count has reached since the last reset. */
  size_t get_peak_memory() const {
    std::unique_lock lk(mutex_);
    return peak_memory_;
  }

  /** Reset the recorded peak to zero. */
  void reset_peak_memory() {
    std::unique_lock lk(mutex_);
    peak_memory_ = 0;
  }

  /** Current memory limit in bytes. */
  size_t get_memory_limit() const {
    std::unique_lock lk(mutex_);
    return memory_limit_;
  }

  /** Install a new limit (bytes) and return the previous one. */
  size_t set_memory_limit(size_t limit) {
    std::unique_lock lk(mutex_);
    std::swap(memory_limit_, limit);
    return limit;
  }

 private:
  size_t memory_limit_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
  // mutable so that the const getters can take the lock.
  mutable std::mutex mutex_;

  // The limit defaults to 80% of the size reported by get_memory_size().
  // When that size is unknown (0), fall back to 8 GiB (1 << 33).
  CommonAllocator()
      : memory_limit_(static_cast<size_t>(0.8 * get_memory_size())) {
    if (memory_limit_ == 0) {
      // 1ULL rather than 1UL: unsigned long is only 32 bits wide on LLP64
      // targets (e.g. Windows), where shifting it by 33 is undefined.
      memory_limit_ = 1ULL << 33;
    }
  }

  friend CommonAllocator& common_allocator();
};
// Accessor for the single CommonAllocator instance, which is a function-local
// static and therefore constructed the first time this function is called.
CommonAllocator& common_allocator() {
  static CommonAllocator instance;
  return instance;
}
// Expose the CommonAllocator instance through the generic Allocator
// interface.
Allocator& allocator() {
  Allocator& base = common_allocator();
  return base;
}
// Return the address one size_t past the stored pointer, i.e. skipping the
// size header that CommonAllocator::malloc writes at the start of the block.
// A null buffer yields nullptr.
void* Buffer::raw_ptr() {
  if (ptr_ == nullptr) {
    return nullptr;
  }
  auto* header = static_cast<size_t*>(ptr_);
  return header + 1;
}
// Allocate `size` usable bytes preceded by a size_t header that records
// `size`, and account for the allocation in the active/peak statistics.
//
// Returns a Buffer wrapping the start of the header, or a Buffer wrapping
// nullptr when the request cannot be satisfied. A failed request leaves the
// statistics untouched.
Buffer CommonAllocator::malloc(size_t size) {
  constexpr size_t header_bytes = sizeof(size_t);

  void* ptr = nullptr;
  // Skip the allocation if adding the header would wrap around size_t;
  // otherwise std::malloc would be asked for a tiny block.
  if (size <= static_cast<size_t>(-1) - header_bytes) {
    ptr = std::malloc(size + header_bytes);
  }
  if (ptr == nullptr) {
    // Nothing was allocated, so nothing may be counted: free() of a null
    // buffer subtracts 0 and could never undo a spurious increment.
    return Buffer{ptr};
  }

  *static_cast<size_t*>(ptr) = size;

  std::unique_lock lk(mutex_);
  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);
  return Buffer{ptr};
}
// Release the block wrapped by `buffer` and subtract its recorded size from
// the active byte count. The size is read before the block is freed because
// it lives in the block's own header.
void CommonAllocator::free(Buffer buffer) {
  const size_t released = size(buffer);
  std::free(buffer.ptr());

  std::lock_guard<std::mutex> guard(mutex_);
  active_memory_ -= released;
}
// Usable size of `buffer` in bytes, as recorded in the size_t header at the
// start of the block; 0 when the buffer wraps nullptr.
size_t CommonAllocator::size(Buffer buffer) const {
  void* block = buffer.ptr();
  if (block != nullptr) {
    return *static_cast<size_t*>(block);
  }
  return 0;
}
} // namespace allocator
// Bytes currently tracked as active by the common CPU allocator.
size_t get_active_memory() {
  return allocator::common_allocator().get_active_memory();
}
// Peak of the active byte count recorded by the common CPU allocator.
size_t get_peak_memory() {
  return allocator::common_allocator().get_peak_memory();
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
void reset_peak_memory() {
return allocator::common_allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t) {
return 0;
size_t set_memory_limit(size_t limit) {
return allocator::common_allocator().set_memory_limit(limit);
}
// Memory limit (in bytes) currently stored by the common CPU allocator.
size_t get_memory_limit() {
  auto& cpu_allocator = allocator::common_allocator();
  return cpu_allocator.get_memory_limit();
}
// No-ops for common allocator
// This function always reports zero cached bytes.
size_t get_cache_memory() {
  constexpr size_t cached_bytes = 0;
  return cached_bytes;
}
size_t set_cache_limit(size_t) {

View File

@@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysctl.h>
namespace {
// Query the "hw.memsize" sysctl and return its value.
// Returns 0 when the query fails, so callers can treat 0 as "unknown".
size_t get_memory_size() {
  size_t memsize = 0;
  size_t length = sizeof(memsize);
  if (sysctlbyname("hw.memsize", &memsize, &length, nullptr, 0) != 0) {
    return 0;
  }
  return memsize;
}
} // namespace

View File

@@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysinfo.h>
namespace {
// Total RAM in bytes as reported by sysinfo(2): `totalram` scaled by
// `mem_unit`. Returns 0 when the sysinfo call fails.
size_t get_memory_size() {
  struct sysinfo stats;
  const bool ok = (sysinfo(&stats) == 0);
  if (!ok) {
    return 0;
  }
  // Widen to size_t before multiplying, exactly as the two-step form does.
  const size_t units = stats.totalram;
  return units * stats.mem_unit;
}
} // namespace