diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index e1507c631..75bd60a68 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -97,7 +97,6 @@ CudaAllocator::CudaAllocator() Buffer CudaAllocator::malloc(size_t size) { // Find available buffer from cache. - auto orig_size = size; std::unique_lock lock(mutex_); if (size <= small_block_size) { size = 8; @@ -131,7 +130,7 @@ Buffer CudaAllocator::malloc(size_t size) { } lock.lock(); } - active_memory_ += size; + active_memory_ += buf->size; peak_memory_ = std::max(active_memory_, peak_memory_); // Maintain the cache below the requested limit. diff --git a/python/tests/test_memory.py b/python/tests/test_memory.py index 08da7ccc6..a61e4d879 100644 --- a/python/tests/test_memory.py +++ b/python/tests/test_memory.py @@ -58,6 +58,20 @@ class TestMemory(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.set_wired_limit(max_size + 10) + def test_active_memory_count(self): + mx.synchronize() + mx.clear_cache() + init_mem = mx.get_active_memory() + a = mx.zeros((128, 128)) + mx.eval(a) + mx.synchronize() + del a + a = mx.zeros((90, 128)) + mx.eval(a) + mx.synchronize() + del a + self.assertEqual(init_mem, mx.get_active_memory()) + if __name__ == "__main__": mlx_tests.MLXTestRunner()