mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix memory count bug (#2717)
This commit is contained in:
		| @@ -97,7 +97,6 @@ CudaAllocator::CudaAllocator() | ||||
|  | ||||
| Buffer CudaAllocator::malloc(size_t size) { | ||||
|   // Find available buffer from cache. | ||||
|   auto orig_size = size; | ||||
|   std::unique_lock lock(mutex_); | ||||
|   if (size <= small_block_size) { | ||||
|     size = 8; | ||||
| @@ -131,7 +130,7 @@ Buffer CudaAllocator::malloc(size_t size) { | ||||
|     } | ||||
|     lock.lock(); | ||||
|   } | ||||
|   active_memory_ += size; | ||||
|   active_memory_ += buf->size; | ||||
|   peak_memory_ = std::max(active_memory_, peak_memory_); | ||||
|  | ||||
|   // Maintain the cache below the requested limit. | ||||
|   | ||||
| @@ -58,6 +58,20 @@ class TestMemory(mlx_tests.MLXTestCase): | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.set_wired_limit(max_size + 10) | ||||
|  | ||||
|     def test_active_memory_count(self): | ||||
|         mx.synchronize() | ||||
|         mx.clear_cache() | ||||
|         init_mem = mx.get_active_memory() | ||||
|         a = mx.zeros((128, 128)) | ||||
|         mx.eval(a) | ||||
|         mx.synchronize() | ||||
|         del a | ||||
|         a = mx.zeros((90, 128)) | ||||
|         mx.eval(a) | ||||
|         mx.synchronize() | ||||
|         del a | ||||
|         self.assertEqual(init_mem, mx.get_active_memory()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     mlx_tests.MLXTestRunner() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun