Add synchronize function (#1006)

* add synchronize function

* fix linux

* fix linux

* fix and fix docs

* fix test

* try synchronize in stream destroy

* synchronize works for both cpu and gpu
This commit is contained in:
Awni Hannun
2024-04-22 08:25:46 -07:00
committed by GitHub
parent b0012cdd0f
commit 3d405fb3b1
14 changed files with 95 additions and 23 deletions

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/metal/metal.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
namespace nb = nanobind;
@@ -99,9 +100,6 @@ void init_metal(nb::module_& m) {
Args:
path (str): The path to save the capture which should have
the extension ``.gputrace``.
Returns:
bool: Whether the capture was successfully started.
)pbdoc");
metal.def(
"stop_capture",

View File

@@ -129,4 +129,17 @@ void init_stream(nb::module_& m) {
# Operations here will use mx.cpu by default.
pass
)pbdoc");
m.def(
"synchronize",
[](const std::optional<Stream>& s) {
s ? synchronize(s.value()) : synchronize();
},
"stream"_a = nb::none(),
R"pbdoc(
Synchronize with the given stream.
Args:
(Stream, optional): The stream to synchronize with. If ``None`` then
the default stream of the default device is used. Default: ``None``.
)pbdoc");
}

View File

@@ -24,14 +24,16 @@ class TestMetal(mlx_tests.MLXTestCase):
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
# Query active and peak memory
a = mx.zeros((4096,), stream=mx.cpu)
a = mx.zeros((4096,))
mx.eval(a)
mx.synchronize()
active_mem = mx.metal.get_active_memory()
self.assertTrue(active_mem >= 4096 * 4)
b = mx.zeros((4096,), stream=mx.cpu)
b = mx.zeros((4096,))
mx.eval(b)
del b
mx.synchronize()
new_active_mem = mx.metal.get_active_memory()
self.assertEqual(new_active_mem, active_mem)