mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests * linux memory and default * fix
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
@@ -22,23 +21,4 @@ void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -49,16 +49,4 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
friend Allocator& allocator();
|
||||
};
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -1,38 +1,123 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include "mlx/backend/no_metal/apple_memory.h"
|
||||
#elif defined(__linux__)
|
||||
#include "mlx/backend/no_metal/linux_memory.h"
|
||||
#else
|
||||
size_t get_memory_size() {
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
size_t get_active_memory() const {
|
||||
return active_memory_;
|
||||
};
|
||||
size_t get_peak_memory() const {
|
||||
return peak_memory_;
|
||||
};
|
||||
void reset_peak_memory() {
|
||||
std::unique_lock lk(mutex_);
|
||||
peak_memory_ = 0;
|
||||
};
|
||||
size_t get_memory_limit() {
|
||||
return memory_limit_;
|
||||
}
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
std::unique_lock lk(mutex_);
|
||||
std::swap(memory_limit_, limit);
|
||||
return limit;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t memory_limit_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
std::mutex mutex_;
|
||||
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
|
||||
if (memory_limit_ == 0) {
|
||||
memory_limit_ = 1UL << 33;
|
||||
}
|
||||
};
|
||||
|
||||
friend CommonAllocator& common_allocator();
|
||||
};
|
||||
|
||||
CommonAllocator& common_allocator() {
|
||||
static CommonAllocator allocator_;
|
||||
return allocator_;
|
||||
}
|
||||
|
||||
Allocator& allocator() {
|
||||
return common_allocator();
|
||||
}
|
||||
|
||||
void* Buffer::raw_ptr() {
|
||||
if (!ptr_) {
|
||||
return nullptr;
|
||||
}
|
||||
return static_cast<size_t*>(ptr_) + 1;
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
auto sz = size(buffer);
|
||||
std::free(buffer.ptr());
|
||||
std::unique_lock lk(mutex_);
|
||||
active_memory_ -= sz;
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
} // namespace allocator
|
||||
|
||||
size_t get_active_memory() {
|
||||
return 0;
|
||||
return allocator::common_allocator().get_active_memory();
|
||||
}
|
||||
size_t get_peak_memory() {
|
||||
return 0;
|
||||
return allocator::common_allocator().get_peak_memory();
|
||||
}
|
||||
void reset_peak_memory() {}
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
void reset_peak_memory() {
|
||||
return allocator::common_allocator().reset_peak_memory();
|
||||
}
|
||||
size_t set_memory_limit(size_t) {
|
||||
return 0;
|
||||
size_t set_memory_limit(size_t limit) {
|
||||
return allocator::common_allocator().set_memory_limit(limit);
|
||||
}
|
||||
size_t get_memory_limit() {
|
||||
return allocator::common_allocator().get_memory_limit();
|
||||
}
|
||||
|
||||
// No-ops for common allocator
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
|
16
mlx/backend/no_metal/apple_memory.h
Normal file
16
mlx/backend/no_metal/apple_memory.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
namespace {
|
||||
|
||||
size_t get_memory_size() {
|
||||
size_t memsize = 0;
|
||||
size_t length = sizeof(memsize);
|
||||
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||
return memsize;
|
||||
}
|
||||
|
||||
} // namespace
|
22
mlx/backend/no_metal/linux_memory.h
Normal file
22
mlx/backend/no_metal/linux_memory.h
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/sysinfo.h>
|
||||
|
||||
namespace {
|
||||
|
||||
size_t get_memory_size() {
|
||||
struct sysinfo info;
|
||||
|
||||
if (sysinfo(&info) != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t total_ram = info.totalram;
|
||||
total_ram *= info.mem_unit;
|
||||
|
||||
return total_ram;
|
||||
}
|
||||
|
||||
} // namespace
|
Reference in New Issue
Block a user