fix memory count bug (#2717)

This commit is contained in:
Awni Hannun
2025-10-30 14:27:15 -07:00
committed by GitHub
parent 793a31eeb6
commit 68c5fa1c95
2 changed files with 15 additions and 2 deletions

View File

@@ -97,7 +97,6 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size <= small_block_size) {
size = 8;
@@ -131,7 +130,7 @@ Buffer CudaAllocator::malloc(size_t size) {
}
lock.lock();
}
active_memory_ += size;
active_memory_ += buf->size;
peak_memory_ = std::max(active_memory_, peak_memory_);
// Maintain the cache below the requested limit.

View File

@@ -58,6 +58,20 @@ class TestMemory(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.set_wired_limit(max_size + 10)
# Regression test for active-memory accounting: after allocating and freeing
# two arrays of different sizes, the reported active memory must equal the
# value sampled before any allocation.
def test_active_memory_count(self):
# Synchronize and clear the cache before recording the baseline reading.
mx.synchronize()
mx.clear_cache()
init_mem = mx.get_active_memory()
# First allocation: a 128x128 array, evaluated and then released.
a = mx.zeros((128, 128))
mx.eval(a)
mx.synchronize()
del a
# Second allocation is smaller (90x128) than the one just released.
# NOTE(review): presumably this request is served from the cached larger
# buffer, which is the case where counting the requested size instead of
# the buffer's size would skew the counter; confirm against the allocator.
a = mx.zeros((90, 128))
mx.eval(a)
mx.synchronize()
del a
# With both arrays released, the counter must be back at the baseline.
self.assertEqual(init_mem, mx.get_active_memory())
# When this file is executed directly, hand control to the project's test
# runner (mlx_tests.MLXTestRunner; its behavior is defined outside this view).
if __name__ == "__main__":
mlx_tests.MLXTestRunner()