mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
Add synchronize function (#1006)
* add synchronize function * fix linux * fix linux * fix and fix docs * fix test * try synchronize in stream destroy * synchronize works for both cpu and gpu
This commit is contained in:
parent
b0012cdd0f
commit
3d405fb3b1
@ -32,10 +32,9 @@ work.
|
|||||||
|
|
||||||
trace_file = "mlx_trace.gputrace"
|
trace_file = "mlx_trace.gputrace"
|
||||||
|
|
||||||
if not mx.metal.start_capture(trace_file):
|
# Make sure to run with MTL_CAPTURE_ENABLED=1 and
|
||||||
print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
|
# that the path trace_file does not already exist.
|
||||||
f"that the path {trace_file} does not already exist.")
|
mx.metal.start_capture(trace_file)
|
||||||
exit(1)
|
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
mx.eval(mx.add(a, b))
|
mx.eval(mx.add(a, b))
|
||||||
|
@ -16,3 +16,4 @@ Devices and Streams
|
|||||||
new_stream
|
new_stream
|
||||||
set_default_stream
|
set_default_stream
|
||||||
stream
|
stream
|
||||||
|
synchronize
|
||||||
|
@ -11,7 +11,7 @@ int main() {
|
|||||||
// To use Metal debugging and profiling:
|
// To use Metal debugging and profiling:
|
||||||
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
// 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
|
||||||
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
// 2. Run with MTL_CAPTURE_ENABLED=1.
|
||||||
assert(metal::start_capture("mlx_trace.gputrace"));
|
metal::start_capture("mlx_trace.gputrace");
|
||||||
|
|
||||||
// Start at index two because the default GPU and CPU streams have indices
|
// Start at index two because the default GPU and CPU streams have indices
|
||||||
// zero and one, respectively. This naming matches the label assigned to each
|
// zero and one, respectively. This naming matches the label assigned to each
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <future>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
@ -115,13 +114,31 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool start_capture(std::string path, id object) {
|
std::function<void()> make_synchronize_task(
|
||||||
|
Stream s,
|
||||||
|
std::shared_ptr<std::promise<void>> p) {
|
||||||
|
return [s, p = std::move(p)]() {
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto cb = d.get_command_buffer(s.index);
|
||||||
|
if (cb == nullptr) {
|
||||||
|
cb = d.new_command_buffer(s.index);
|
||||||
|
} else {
|
||||||
|
d.end_encoding(s.index);
|
||||||
|
}
|
||||||
|
d.commit_command_buffer(s.index);
|
||||||
|
cb->waitUntilCompleted();
|
||||||
|
check_error(cb);
|
||||||
|
p->set_value();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
void start_capture(std::string path, id object) {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||||
descriptor->setCaptureObject(object);
|
descriptor->setCaptureObject(object);
|
||||||
|
|
||||||
if (path.length() > 0) {
|
if (!path.empty()) {
|
||||||
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||||
auto url = NS::URL::fileURLWithPath(string);
|
auto url = NS::URL::fileURLWithPath(string);
|
||||||
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||||
@ -129,15 +146,24 @@ bool start_capture(std::string path, id object) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||||
return manager->startCapture(descriptor, nullptr);
|
NS::Error* error;
|
||||||
|
bool started = manager->startCapture(descriptor, &error);
|
||||||
|
descriptor->release();
|
||||||
|
if (!started) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[metal::start_capture] Failed to start: "
|
||||||
|
<< error->localizedDescription()->utf8String();
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool start_capture(std::string path) {
|
void start_capture(std::string path) {
|
||||||
auto& device = metal::device(mlx::core::Device::gpu);
|
auto& device = metal::device(mlx::core::Device::gpu);
|
||||||
return start_capture(path, device.mtl_device());
|
return start_capture(path, device.mtl_device());
|
||||||
}
|
}
|
||||||
|
|
||||||
void stop_capture() {
|
void stop_capture() {
|
||||||
|
auto pool = new_scoped_memory_pool();
|
||||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||||
manager->stopCapture();
|
manager->stopCapture();
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
|
|||||||
size_t set_cache_limit(size_t limit);
|
size_t set_cache_limit(size_t limit);
|
||||||
|
|
||||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||||
bool start_capture(std::string path = "");
|
void start_capture(std::string path = "");
|
||||||
void stop_capture();
|
void stop_capture();
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <future>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
@ -15,4 +16,8 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
|
|||||||
|
|
||||||
std::function<void()> make_task(array arr, bool signal);
|
std::function<void()> make_task(array arr, bool signal);
|
||||||
|
|
||||||
|
std::function<void()> make_synchronize_task(
|
||||||
|
Stream s,
|
||||||
|
std::shared_ptr<std::promise<void>> p);
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -22,6 +22,14 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
"[metal::make_task] Cannot make GPU task without metal backend");
|
"[metal::make_task] Cannot make GPU task without metal backend");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::function<void()> make_synchronize_task(
|
||||||
|
Stream s,
|
||||||
|
std::shared_ptr<std::promise<void>> p) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[metal::make_synchronize_task] Cannot synchronize GPU"
|
||||||
|
" without metal backend");
|
||||||
|
}
|
||||||
|
|
||||||
// No-ops when Metal is not available.
|
// No-ops when Metal is not available.
|
||||||
size_t get_active_memory() {
|
size_t get_active_memory() {
|
||||||
return 0;
|
return 0;
|
||||||
@ -38,9 +46,7 @@ size_t set_memory_limit(size_t, bool) {
|
|||||||
size_t set_cache_limit(size_t) {
|
size_t set_cache_limit(size_t) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
bool start_capture(std::string path) {
|
void start_capture(std::string path) {}
|
||||||
return false;
|
|
||||||
}
|
|
||||||
void stop_capture() {}
|
void stop_capture() {}
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -33,6 +33,21 @@ Stream new_stream() {
|
|||||||
return scheduler::scheduler().new_stream(default_device());
|
return scheduler::scheduler().new_stream(default_device());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void synchronize(Stream s) {
|
||||||
|
auto p = std::make_shared<std::promise<void>>();
|
||||||
|
std::future<void> f = p->get_future();
|
||||||
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
|
scheduler::enqueue(s, [p = std::move(p)]() { p->set_value(); });
|
||||||
|
} else {
|
||||||
|
scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(p)));
|
||||||
|
}
|
||||||
|
f.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
void synchronize() {
|
||||||
|
synchronize(default_stream(default_device()));
|
||||||
|
}
|
||||||
|
|
||||||
namespace scheduler {
|
namespace scheduler {
|
||||||
|
|
||||||
/** A singleton scheduler to manage devices, streams, and task execution. */
|
/** A singleton scheduler to manage devices, streams, and task execution. */
|
||||||
|
@ -27,6 +27,7 @@ struct StreamThread {
|
|||||||
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
||||||
|
|
||||||
~StreamThread() {
|
~StreamThread() {
|
||||||
|
synchronize(stream);
|
||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lk(mtx);
|
std::unique_lock<std::mutex> lk(mtx);
|
||||||
stop = true;
|
stop = true;
|
||||||
|
@ -29,4 +29,10 @@ inline bool operator!=(const Stream& lhs, const Stream& rhs) {
|
|||||||
return !(lhs == rhs);
|
return !(lhs == rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Synchronize with the default stream. */
|
||||||
|
void synchronize();
|
||||||
|
|
||||||
|
/* Synchronize with the provided stream. */
|
||||||
|
void synchronize(Stream);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
@ -99,9 +100,6 @@ void init_metal(nb::module_& m) {
|
|||||||
Args:
|
Args:
|
||||||
path (str): The path to save the capture which should have
|
path (str): The path to save the capture which should have
|
||||||
the extension ``.gputrace``.
|
the extension ``.gputrace``.
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether the capture was successfully started.
|
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
metal.def(
|
metal.def(
|
||||||
"stop_capture",
|
"stop_capture",
|
||||||
|
@ -129,4 +129,17 @@ void init_stream(nb::module_& m) {
|
|||||||
# Operations here will use mx.cpu by default.
|
# Operations here will use mx.cpu by default.
|
||||||
pass
|
pass
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"synchronize",
|
||||||
|
[](const std::optional<Stream>& s) {
|
||||||
|
s ? synchronize(s.value()) : synchronize();
|
||||||
|
},
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
R"pbdoc(
|
||||||
|
Synchronize with the given stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
(Stream, optional): The stream to synchronize with. If ``None`` then
|
||||||
|
the default stream of the default device is used. Default: ``None``.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -24,14 +24,16 @@ class TestMetal(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
|
||||||
|
|
||||||
# Query active and peak memory
|
# Query active and peak memory
|
||||||
a = mx.zeros((4096,), stream=mx.cpu)
|
a = mx.zeros((4096,))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
|
mx.synchronize()
|
||||||
active_mem = mx.metal.get_active_memory()
|
active_mem = mx.metal.get_active_memory()
|
||||||
self.assertTrue(active_mem >= 4096 * 4)
|
self.assertTrue(active_mem >= 4096 * 4)
|
||||||
|
|
||||||
b = mx.zeros((4096,), stream=mx.cpu)
|
b = mx.zeros((4096,))
|
||||||
mx.eval(b)
|
mx.eval(b)
|
||||||
del b
|
del b
|
||||||
|
mx.synchronize()
|
||||||
|
|
||||||
new_active_mem = mx.metal.get_active_memory()
|
new_active_mem = mx.metal.get_active_memory()
|
||||||
self.assertEqual(new_active_mem, active_mem)
|
self.assertEqual(new_active_mem, active_mem)
|
||||||
|
@ -495,16 +495,16 @@ TEST_CASE("test metal memory info") {
|
|||||||
|
|
||||||
// Query active and peak memory
|
// Query active and peak memory
|
||||||
{
|
{
|
||||||
// Do these tests on the CPU since deallocation is synchronized
|
auto a = zeros({4096});
|
||||||
// with the main thread.
|
|
||||||
auto a = zeros({4096}, Device::cpu);
|
|
||||||
eval(a);
|
eval(a);
|
||||||
|
synchronize();
|
||||||
auto active_mem = metal::get_active_memory();
|
auto active_mem = metal::get_active_memory();
|
||||||
CHECK(active_mem >= 4096 * 4);
|
CHECK(active_mem >= 4096 * 4);
|
||||||
{
|
{
|
||||||
auto b = zeros({4096}, Device::cpu);
|
auto b = zeros({4096});
|
||||||
eval(b);
|
eval(b);
|
||||||
}
|
}
|
||||||
|
synchronize();
|
||||||
auto new_active_mem = metal::get_active_memory();
|
auto new_active_mem = metal::get_active_memory();
|
||||||
CHECK_EQ(new_active_mem, active_mem);
|
CHECK_EQ(new_active_mem, active_mem);
|
||||||
auto peak_mem = metal::get_peak_memory();
|
auto peak_mem = metal::get_peak_memory();
|
||||||
|
Loading…
Reference in New Issue
Block a user