Add stats and limit to common allocator and enable tests (#1988)

* add stats to common allocator and enable tests

* linux memory and default

* fix
This commit is contained in:
Awni Hannun 2025-03-21 12:28:36 -07:00 committed by GitHub
parent d343782c8b
commit 2a980a76ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 151 additions and 68 deletions

View File

@ -4,7 +4,6 @@
#include <sstream>
#include "mlx/allocator.h"
#include "mlx/scheduler.h"
namespace mlx::core::allocator {
@ -22,23 +21,4 @@ void free(Buffer buffer) {
allocator().free(buffer);
}
// Allocate `size` usable bytes, prefixing the block with a size_t header
// that records the requested size for later size() queries.
Buffer CommonAllocator::malloc(size_t size) {
  void* block = std::malloc(size + sizeof(size_t));
  if (block == nullptr) {
    return Buffer{block};
  }
  // Stamp the requested size into the header word.
  *static_cast<size_t*>(block) = size;
  return Buffer{block};
}
// Release a block previously returned by CommonAllocator::malloc.
void CommonAllocator::free(Buffer buffer) {
  void* base = buffer.ptr();
  std::free(base);
}
// Report the requested size recorded in the block's size_t header;
// a null buffer has size zero.
size_t CommonAllocator::size(Buffer buffer) const {
  void* base = buffer.ptr();
  return base == nullptr ? 0 : *static_cast<size_t*>(base);
}
} // namespace mlx::core::allocator

View File

@ -49,16 +49,4 @@ class Allocator {
Allocator& allocator();
// Plain std::malloc-backed allocator used when no accelerator backend
// provides its own allocator.
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private:
// Singleton: only constructible through the allocator() accessor.
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator

View File

@ -1,38 +1,123 @@
// Copyright © 2023 Apple Inc.
#include <mutex>
#include "mlx/allocator.h"
#ifdef __APPLE__
#include "mlx/backend/no_metal/apple_memory.h"
#elif defined(__linux__)
#include "mlx/backend/no_metal/linux_memory.h"
#else
size_t get_memory_size() {
// Unknown platform: report zero; the allocator falls back to a
// fixed default memory limit in that case.
return 0;
}
#endif
namespace mlx::core {
namespace allocator {
Allocator& allocator() {
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() const {
return active_memory_;
};
size_t get_peak_memory() const {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_memory_limit() {
return memory_limit_;
}
size_t set_memory_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(memory_limit_, limit);
return limit;
}
private:
size_t memory_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
std::mutex mutex_;
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
if (memory_limit_ == 0) {
memory_limit_ = 1UL << 33;
}
};
friend CommonAllocator& common_allocator();
};
// Accessor for the process-wide CommonAllocator.
// Meyers singleton: constructed on first use, thread-safe since C++11.
CommonAllocator& common_allocator() {
  static CommonAllocator instance;
  return instance;
}
// The process-wide allocator: in this (no-accelerator) build it is the
// common std::malloc based allocator.
Allocator& allocator() {
return common_allocator();
}
// Pointer to the user-visible region: skip the leading size_t header
// written by CommonAllocator::malloc. Null buffers yield nullptr.
void* Buffer::raw_ptr() {
  return ptr_ ? static_cast<void*>(static_cast<size_t*>(ptr_) + 1) : nullptr;
}
// Allocate `size` usable bytes prefixed by a size_t header recording the
// requested size, and update the active/peak statistics.
Buffer CommonAllocator::malloc(size_t size) {
  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr != nullptr) {
    *static_cast<size_t*>(ptr) = size;
    // Only account for the allocation when it actually succeeded;
    // otherwise active_memory_ drifts upward on failed allocations.
    std::unique_lock lk(mutex_);
    active_memory_ += size;
    peak_memory_ = std::max(active_memory_, peak_memory_);
  }
  return Buffer{ptr};
}
// Release a block allocated by malloc() and deduct its recorded size
// from the active-memory counter.
void CommonAllocator::free(Buffer buffer) {
  const size_t released = size(buffer);
  std::free(buffer.ptr());
  std::unique_lock lk(mutex_);
  active_memory_ -= released;
}
// Requested size of the buffer, read back from its size_t header.
size_t CommonAllocator::size(Buffer buffer) const {
  void* base = buffer.ptr();
  if (base == nullptr) {
    return 0;
  }
  return *static_cast<size_t*>(base);
}
} // namespace allocator
// Bytes currently allocated through the common allocator.
size_t get_active_memory() {
  return allocator::common_allocator().get_active_memory();
}
// High-water mark of active memory since start or the last reset.
size_t get_peak_memory() {
  return allocator::common_allocator().get_peak_memory();
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
void reset_peak_memory() {
return allocator::common_allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t) {
return 0;
size_t set_memory_limit(size_t limit) {
return allocator::common_allocator().set_memory_limit(limit);
}
// Currently configured memory limit of the common allocator.
size_t get_memory_limit() {
return allocator::common_allocator().get_memory_limit();
}
// No-ops for common allocator
// The common allocator keeps no buffer cache, so cache queries always
// report zero.
size_t get_cache_memory() {
return 0;
}
size_t set_cache_limit(size_t) {

View File

@ -0,0 +1,16 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysctl.h>
namespace {
// Total physical memory in bytes, queried via the hw.memsize sysctl.
// Returns 0 when the query fails so callers can fall back to a default.
size_t get_memory_size() {
  size_t memsize = 0;
  size_t length = sizeof(memsize);
  if (sysctlbyname("hw.memsize", &memsize, &length, nullptr, 0) != 0) {
    return 0;
  }
  return memsize;
}
} // namespace

View File

@ -0,0 +1,22 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <sys/sysinfo.h>
namespace {
// Total physical memory in bytes via sysinfo(2); 0 on failure.
size_t get_memory_size() {
  struct sysinfo info;
  if (sysinfo(&info) != 0) {
    return 0;
  }
  // totalram is expressed in units of mem_unit bytes; widen before
  // multiplying to avoid overflow on 32-bit fields.
  return static_cast<size_t>(info.totalram) * info.mem_unit;
}
} // namespace

View File

@ -36,7 +36,7 @@ void init_metal(nb::module_& m) {
});
metal.def("get_peak_memory", []() {
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
return mx::get_active_memory();
return mx::get_peak_memory();
});
metal.def("reset_peak_memory", []() {
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");

View File

@ -177,17 +177,17 @@ class TestDistributed(mlx_tests.MLXTestCase):
def test_donation(self):
    # Peak memory should be identical whether or not the all_sum output
    # is consumed by a following binary op (the binary op should donate
    # the all_sum buffer rather than allocate a new one).
    x = mx.random.normal((1024,))
    mx.eval(x)
    mx.synchronize()

    mx.reset_peak_memory()
    scale = mx.array(2.0)
    y = mx.distributed.all_sum(x)
    mx.eval(y)
    mx.synchronize()
    all_sum_only = mx.get_peak_memory()
    y = mx.distributed.all_sum(x) * scale
    mx.eval(y)
    mx.synchronize()
    all_sum_with_binary = mx.get_peak_memory()

    self.assertEqual(all_sum_only, all_sum_with_binary)

View File

@ -1803,7 +1803,6 @@ class TestArray(mlx_tests.MLXTestCase):
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_multi_output_leak(self):
def fun():
a = mx.zeros((2**20))

View File

@ -745,11 +745,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.custom_function,
mx.checkpoint,
]:
if mx.metal.is_available():
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mem_pre = mx.get_active_memory()
else:
mem_pre = 0
def outer():
d = {}
@ -763,12 +760,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
for _ in range(5):
outer()
gc.collect()
if mx.metal.is_available():
mem_post = mx.get_active_memory()
else:
mem_post = 0
self.assertEqual(mem_pre, mem_post)
def test_grad_with_copies(self):

View File

@ -117,7 +117,6 @@ class TestEval(mlx_tests.MLXTestCase):
out = mx.vjp(fn, (x,), (y,))
out = mx.vjp(fn, (x,), (y,))
if mx.metal.is_available():
peak_mem = mx.get_peak_memory()
out = mx.vjp(fn, (x,), (y,))
self.assertEqual(peak_mem, mx.get_peak_memory())
@ -137,7 +136,6 @@ class TestEval(mlx_tests.MLXTestCase):
mx.async_eval(x)
mx.eval(a + b)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_donation_for_noops(self):
def fun(x):
s = x.shape

View File

@ -385,7 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
mx.eval(x)
save_file = os.path.join(self.test_dir, "donation.npy")
mx.save(save_file, x)
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)

View File

@ -7,7 +7,6 @@ import mlx_tests
class TestMemory(mlx_tests.MLXTestCase):
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_memory_info(self):
old_limit = mx.set_cache_limit(0)
@ -38,6 +37,8 @@ class TestMemory(mlx_tests.MLXTestCase):
self.assertEqual(new_active_mem, active_mem)
peak_mem = mx.get_peak_memory()
self.assertTrue(peak_mem >= 4096 * 8)
if mx.metal.is_available():
cache_mem = mx.get_cache_memory()
self.assertTrue(cache_mem >= 4096 * 4)
@ -47,6 +48,8 @@ class TestMemory(mlx_tests.MLXTestCase):
mx.reset_peak_memory()
self.assertEqual(mx.get_peak_memory(), 0)
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_wired_memory(self):
old_limit = mx.set_wired_limit(1000)
old_limit = mx.set_wired_limit(0)
self.assertEqual(old_limit, 1000)

View File

@ -1901,12 +1901,12 @@ class TestOps(mlx_tests.MLXTestCase):
x = mx.cumsum(x)
return x
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mem2 = mx.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mx.synchronize()
mem4 = mx.get_peak_memory()
self.assertEqual(mem2, mem4)