Reset peak memory (#1074)

* reset peak memory

* fix linux

* nits in docs
This commit is contained in:
Awni Hannun 2024-05-03 17:12:51 -07:00 committed by GitHub
parent 79c859e2e0
commit 21623156a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 26 additions and 4 deletions

View File

@ -10,6 +10,7 @@ Metal
device_info device_info
get_active_memory get_active_memory
get_peak_memory get_peak_memory
reset_peak_memory
get_cache_memory get_cache_memory
set_memory_limit set_memory_limit
set_cache_limit set_cache_limit

View File

@ -258,6 +258,9 @@ size_t get_active_memory() {
size_t get_peak_memory() { size_t get_peak_memory() {
return allocator().get_peak_memory(); return allocator().get_peak_memory();
} }
void reset_peak_memory() {
allocator().reset_peak_memory();
}
size_t get_cache_memory() { size_t get_cache_memory() {
return allocator().get_cache_memory(); return allocator().get_cache_memory();
} }

View File

@ -62,6 +62,10 @@ class MetalAllocator : public allocator::Allocator {
size_t get_peak_memory() { size_t get_peak_memory() {
return peak_memory_; return peak_memory_;
}; };
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_cache_memory() { size_t get_cache_memory() {
return buffer_cache_.cache_size(); return buffer_cache_.cache_size();
}; };

View File

@ -20,11 +20,15 @@ size_t get_active_memory();
/* Get the peak amount of used memory in bytes. /* Get the peak amount of used memory in bytes.
* *
* The maximum memory used is recorded from the beginning of the program * The maximum memory used recorded from the beginning of the program
* execution. * execution or since the last call to reset_peak_memory.
* */ * */
size_t get_peak_memory(); size_t get_peak_memory();
/* Reset the peak memory to zero.
* */
void reset_peak_memory();
/* Get the cache size in bytes. /* Get the cache size in bytes.
* *
* The cache includes memory not currently used that has not been returned * The cache includes memory not currently used that has not been returned

View File

@ -36,6 +36,7 @@ size_t get_active_memory() {
size_t get_peak_memory() { size_t get_peak_memory() {
return 0; return 0;
} }
void reset_peak_memory() {}
size_t get_cache_memory() { size_t get_cache_memory() {
return 0; return 0;
} }

View File

@ -35,8 +35,14 @@ void init_metal(nb::module_& m) {
R"pbdoc( R"pbdoc(
Get the peak amount of used memory in bytes. Get the peak amount of used memory in bytes.
The maximum memory used is recorded from the beginning of the program The maximum memory used recorded from the beginning of the program
execution. execution or since the last call to :func:`reset_peak_memory`.
)pbdoc");
metal.def(
"reset_peak_memory",
&metal::reset_peak_memory,
R"pbdoc(
Reset the peak memory to zero.
)pbdoc"); )pbdoc");
metal.def( metal.def(
"get_cache_memory", "get_cache_memory",

View File

@ -45,6 +45,9 @@ class TestMetal(mlx_tests.MLXTestCase):
mx.metal.clear_cache() mx.metal.clear_cache()
self.assertEqual(mx.metal.get_cache_memory(), 0) self.assertEqual(mx.metal.get_cache_memory(), 0)
mx.metal.reset_peak_memory()
self.assertEqual(mx.metal.get_peak_memory(), 0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()