mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add stats and limit to common allocator and enable tests (#1988)
* add stats to common allocator and enable tests * linux memory and default * fix
This commit is contained in:
parent
d343782c8b
commit
2a980a76ce
@ -4,7 +4,6 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/scheduler.h"
|
|
||||||
|
|
||||||
namespace mlx::core::allocator {
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
@ -22,23 +21,4 @@ void free(Buffer buffer) {
|
|||||||
allocator().free(buffer);
|
allocator().free(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size) {
|
|
||||||
void* ptr = std::malloc(size + sizeof(size_t));
|
|
||||||
if (ptr != nullptr) {
|
|
||||||
*static_cast<size_t*>(ptr) = size;
|
|
||||||
}
|
|
||||||
return Buffer{ptr};
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommonAllocator::free(Buffer buffer) {
|
|
||||||
std::free(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t CommonAllocator::size(Buffer buffer) const {
|
|
||||||
if (buffer.ptr() == nullptr) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return *static_cast<size_t*>(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
@ -49,16 +49,4 @@ class Allocator {
|
|||||||
|
|
||||||
Allocator& allocator();
|
Allocator& allocator();
|
||||||
|
|
||||||
class CommonAllocator : public Allocator {
|
|
||||||
/** A general CPU allocator. */
|
|
||||||
public:
|
|
||||||
virtual Buffer malloc(size_t size) override;
|
|
||||||
virtual void free(Buffer buffer) override;
|
|
||||||
virtual size_t size(Buffer buffer) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
CommonAllocator() = default;
|
|
||||||
friend Allocator& allocator();
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
@ -1,38 +1,123 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#include "mlx/backend/no_metal/apple_memory.h"
|
||||||
|
#elif defined(__linux__)
|
||||||
|
#include "mlx/backend/no_metal/linux_memory.h"
|
||||||
|
#else
|
||||||
|
size_t get_memory_size() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace allocator {
|
namespace allocator {
|
||||||
|
|
||||||
Allocator& allocator() {
|
class CommonAllocator : public Allocator {
|
||||||
|
/** A general CPU allocator. */
|
||||||
|
public:
|
||||||
|
virtual Buffer malloc(size_t size) override;
|
||||||
|
virtual void free(Buffer buffer) override;
|
||||||
|
virtual size_t size(Buffer buffer) const override;
|
||||||
|
size_t get_active_memory() const {
|
||||||
|
return active_memory_;
|
||||||
|
};
|
||||||
|
size_t get_peak_memory() const {
|
||||||
|
return peak_memory_;
|
||||||
|
};
|
||||||
|
void reset_peak_memory() {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
peak_memory_ = 0;
|
||||||
|
};
|
||||||
|
size_t get_memory_limit() {
|
||||||
|
return memory_limit_;
|
||||||
|
}
|
||||||
|
size_t set_memory_limit(size_t limit) {
|
||||||
|
std::unique_lock lk(mutex_);
|
||||||
|
std::swap(memory_limit_, limit);
|
||||||
|
return limit;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t memory_limit_;
|
||||||
|
size_t active_memory_{0};
|
||||||
|
size_t peak_memory_{0};
|
||||||
|
std::mutex mutex_;
|
||||||
|
CommonAllocator() : memory_limit_(0.8 * get_memory_size()) {
|
||||||
|
if (memory_limit_ == 0) {
|
||||||
|
memory_limit_ = 1UL << 33;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
friend CommonAllocator& common_allocator();
|
||||||
|
};
|
||||||
|
|
||||||
|
CommonAllocator& common_allocator() {
|
||||||
static CommonAllocator allocator_;
|
static CommonAllocator allocator_;
|
||||||
return allocator_;
|
return allocator_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the allocator to use, which here is always the object provided by
// common_allocator(), exposed through the Allocator base-class reference.
Allocator& allocator() {
|
||||||
|
return common_allocator();
|
||||||
|
}
|
||||||
|
|
||||||
void* Buffer::raw_ptr() {
|
void* Buffer::raw_ptr() {
|
||||||
if (!ptr_) {
|
if (!ptr_) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return static_cast<size_t*>(ptr_) + 1;
|
return static_cast<size_t*>(ptr_) + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Allocate `size` usable bytes.
 *
 * The underlying allocation is `sizeof(size_t)` bytes larger than requested:
 * the requested size is written at the very start so it can be read back
 * later from the buffer pointer alone.
 *
 * @param size Number of usable bytes requested.
 * @return A Buffer wrapping the allocation, or a Buffer wrapping a null
 *         pointer if the allocation failed. The memory statistics are only
 *         updated on success.
 */
Buffer CommonAllocator::malloc(size_t size) {
  void* ptr = nullptr;

  // Only attempt the allocation when adding the header cannot wrap around;
  // a wrapped sum would yield a block too small to hold even the header.
  constexpr size_t max_size = static_cast<size_t>(-1);
  if (size <= max_size - sizeof(size_t)) {
    ptr = std::malloc(size + sizeof(size_t));
  }

  if (ptr == nullptr) {
    // Nothing was allocated, so the counters must stay untouched; freeing a
    // null buffer subtracts 0 and could never undo a bogus increment.
    return Buffer{ptr};
  }

  // Record the requested size in the header at the start of the block.
  *static_cast<size_t*>(ptr) = size;

  std::unique_lock lk(mutex_);
  active_memory_ += size;
  peak_memory_ = std::max(active_memory_, peak_memory_);
  return Buffer{ptr};
}
|
||||||
|
|
||||||
|
/**
 * Release a buffer and subtract its recorded size from the active byte count.
 *
 * @param buffer The buffer to release.
 */
void CommonAllocator::free(Buffer buffer) {
  // The size lives inside the allocation itself, so it has to be read before
  // the memory is handed back.
  const size_t released_bytes = size(buffer);
  std::free(buffer.ptr());

  std::unique_lock guard(mutex_);
  active_memory_ = active_memory_ - released_bytes;
}
|
||||||
|
|
||||||
|
/**
 * Report the size stored at the start of a buffer's allocation.
 *
 * @param buffer The buffer to inspect.
 * @return 0 for a buffer whose pointer is null, otherwise the size_t value
 *         found at the address the buffer points to.
 */
size_t CommonAllocator::size(Buffer buffer) const {
  auto* base = buffer.ptr();
  return base == nullptr ? 0 : *static_cast<size_t*>(base);
}
|
||||||
|
|
||||||
} // namespace allocator
|
} // namespace allocator
|
||||||
|
|
||||||
size_t get_active_memory() {
|
size_t get_active_memory() {
|
||||||
return 0;
|
return allocator::common_allocator().get_active_memory();
|
||||||
}
|
}
|
||||||
size_t get_peak_memory() {
|
size_t get_peak_memory() {
|
||||||
return 0;
|
return allocator::common_allocator().get_peak_memory();
|
||||||
}
|
}
|
||||||
void reset_peak_memory() {}
|
void reset_peak_memory() {
|
||||||
size_t get_cache_memory() {
|
return allocator::common_allocator().reset_peak_memory();
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
size_t set_memory_limit(size_t) {
|
size_t set_memory_limit(size_t limit) {
|
||||||
return 0;
|
return allocator::common_allocator().set_memory_limit(limit);
|
||||||
}
|
}
|
||||||
size_t get_memory_limit() {
|
size_t get_memory_limit() {
|
||||||
|
return allocator::common_allocator().get_memory_limit();
|
||||||
|
}
|
||||||
|
|
||||||
|
// No-ops for common allocator
|
||||||
|
size_t get_cache_memory() {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
size_t set_cache_limit(size_t) {
|
size_t set_cache_limit(size_t) {
|
||||||
|
16
mlx/backend/no_metal/apple_memory.h
Normal file
16
mlx/backend/no_metal/apple_memory.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/sysctl.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
 * Query the "hw.memsize" sysctl value.
 *
 * @return The value reported by the system, or 0 if the query fails.
 */
size_t get_memory_size() {
  size_t memsize = 0;
  size_t length = sizeof(memsize);
  // Check the result instead of ignoring it, so that a failed query is
  // reported as 0 explicitly rather than depending on the output argument
  // having been left untouched.
  if (sysctlbyname("hw.memsize", &memsize, &length, nullptr, 0) != 0) {
    return 0;
  }
  return memsize;
}
|
||||||
|
|
||||||
|
} // namespace
|
22
mlx/backend/no_metal/linux_memory.h
Normal file
22
mlx/backend/no_metal/linux_memory.h
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <sys/sysinfo.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/**
 * Compute the total RAM figure from sysinfo().
 *
 * @return `totalram` multiplied by `mem_unit` as reported by sysinfo(), or 0
 *         if the sysinfo() call fails.
 */
size_t get_memory_size() {
  struct sysinfo stats;
  if (sysinfo(&stats) == 0) {
    // Widen to size_t before multiplying so the product is formed in size_t.
    return static_cast<size_t>(stats.totalram) * stats.mem_unit;
  }
  return 0;
}
|
||||||
|
|
||||||
|
} // namespace
|
@ -36,7 +36,7 @@ void init_metal(nb::module_& m) {
|
|||||||
});
|
});
|
||||||
metal.def("get_peak_memory", []() {
|
metal.def("get_peak_memory", []() {
|
||||||
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
|
DEPRECATE("mx.metal.get_peak_memory", "mx.get_peak_memory");
|
||||||
return mx::get_active_memory();
|
return mx::get_peak_memory();
|
||||||
});
|
});
|
||||||
metal.def("reset_peak_memory", []() {
|
metal.def("reset_peak_memory", []() {
|
||||||
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
|
DEPRECATE("mx.metal.reset_peak_memory", "mx.reset_peak_memory");
|
||||||
|
@ -177,17 +177,17 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
|||||||
def test_donation(self):
|
def test_donation(self):
|
||||||
x = mx.random.normal((1024,))
|
x = mx.random.normal((1024,))
|
||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
|
|
||||||
mx.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
y = mx.distributed.all_sum(x)
|
y = mx.distributed.all_sum(x)
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
all_sum_only = mx.get_peak_memory()
|
all_sum_only = mx.get_peak_memory()
|
||||||
y = mx.distributed.all_sum(x) * scale
|
y = mx.distributed.all_sum(x) * scale
|
||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
all_sum_with_binary = mx.get_peak_memory()
|
all_sum_with_binary = mx.get_peak_memory()
|
||||||
|
|
||||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||||
|
@ -1803,7 +1803,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
b = pickle.loads(pickle.dumps(a))
|
b = pickle.loads(pickle.dumps(a))
|
||||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||||
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
|
||||||
def test_multi_output_leak(self):
|
def test_multi_output_leak(self):
|
||||||
def fun():
|
def fun():
|
||||||
a = mx.zeros((2**20))
|
a = mx.zeros((2**20))
|
||||||
|
@ -745,11 +745,8 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
mx.custom_function,
|
mx.custom_function,
|
||||||
mx.checkpoint,
|
mx.checkpoint,
|
||||||
]:
|
]:
|
||||||
if mx.metal.is_available():
|
mx.synchronize()
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
|
||||||
mem_pre = mx.get_active_memory()
|
mem_pre = mx.get_active_memory()
|
||||||
else:
|
|
||||||
mem_pre = 0
|
|
||||||
|
|
||||||
def outer():
|
def outer():
|
||||||
d = {}
|
d = {}
|
||||||
@ -763,12 +760,7 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
outer()
|
outer()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if mx.metal.is_available():
|
|
||||||
mem_post = mx.get_active_memory()
|
mem_post = mx.get_active_memory()
|
||||||
else:
|
|
||||||
mem_post = 0
|
|
||||||
|
|
||||||
self.assertEqual(mem_pre, mem_post)
|
self.assertEqual(mem_pre, mem_post)
|
||||||
|
|
||||||
def test_grad_with_copies(self):
|
def test_grad_with_copies(self):
|
||||||
|
@ -117,7 +117,6 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
if mx.metal.is_available():
|
|
||||||
peak_mem = mx.get_peak_memory()
|
peak_mem = mx.get_peak_memory()
|
||||||
out = mx.vjp(fn, (x,), (y,))
|
out = mx.vjp(fn, (x,), (y,))
|
||||||
self.assertEqual(peak_mem, mx.get_peak_memory())
|
self.assertEqual(peak_mem, mx.get_peak_memory())
|
||||||
@ -137,7 +136,6 @@ class TestEval(mlx_tests.MLXTestCase):
|
|||||||
mx.async_eval(x)
|
mx.async_eval(x)
|
||||||
mx.eval(a + b)
|
mx.eval(a + b)
|
||||||
|
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
|
||||||
def test_donation_for_noops(self):
|
def test_donation_for_noops(self):
|
||||||
def fun(x):
|
def fun(x):
|
||||||
s = x.shape
|
s = x.shape
|
||||||
|
@ -385,7 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
mx.eval(x)
|
mx.eval(x)
|
||||||
save_file = os.path.join(self.test_dir, "donation.npy")
|
save_file = os.path.join(self.test_dir, "donation.npy")
|
||||||
mx.save(save_file, x)
|
mx.save(save_file, x)
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
|
|
||||||
mx.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
scale = mx.array(2.0)
|
scale = mx.array(2.0)
|
||||||
|
@ -7,7 +7,6 @@ import mlx_tests
|
|||||||
|
|
||||||
|
|
||||||
class TestMemory(mlx_tests.MLXTestCase):
|
class TestMemory(mlx_tests.MLXTestCase):
|
||||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
|
||||||
def test_memory_info(self):
|
def test_memory_info(self):
|
||||||
old_limit = mx.set_cache_limit(0)
|
old_limit = mx.set_cache_limit(0)
|
||||||
|
|
||||||
@ -38,6 +37,8 @@ class TestMemory(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(new_active_mem, active_mem)
|
self.assertEqual(new_active_mem, active_mem)
|
||||||
peak_mem = mx.get_peak_memory()
|
peak_mem = mx.get_peak_memory()
|
||||||
self.assertTrue(peak_mem >= 4096 * 8)
|
self.assertTrue(peak_mem >= 4096 * 8)
|
||||||
|
|
||||||
|
if mx.metal.is_available():
|
||||||
cache_mem = mx.get_cache_memory()
|
cache_mem = mx.get_cache_memory()
|
||||||
self.assertTrue(cache_mem >= 4096 * 4)
|
self.assertTrue(cache_mem >= 4096 * 4)
|
||||||
|
|
||||||
@ -47,6 +48,8 @@ class TestMemory(mlx_tests.MLXTestCase):
|
|||||||
mx.reset_peak_memory()
|
mx.reset_peak_memory()
|
||||||
self.assertEqual(mx.get_peak_memory(), 0)
|
self.assertEqual(mx.get_peak_memory(), 0)
|
||||||
|
|
||||||
|
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_wired_memory(self):
|
||||||
old_limit = mx.set_wired_limit(1000)
|
old_limit = mx.set_wired_limit(1000)
|
||||||
old_limit = mx.set_wired_limit(0)
|
old_limit = mx.set_wired_limit(0)
|
||||||
self.assertEqual(old_limit, 1000)
|
self.assertEqual(old_limit, 1000)
|
||||||
|
@ -1901,12 +1901,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
x = mx.cumsum(x)
|
x = mx.cumsum(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
mx.eval(fn(2))
|
mx.eval(fn(2))
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
mem2 = mx.get_peak_memory()
|
mem2 = mx.get_peak_memory()
|
||||||
mx.eval(fn(4))
|
mx.eval(fn(4))
|
||||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
mx.synchronize()
|
||||||
mem4 = mx.get_peak_memory()
|
mem4 = mx.get_peak_memory()
|
||||||
self.assertEqual(mem2, mem4)
|
self.assertEqual(mem2, mem4)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user