mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add synchronize function (#1006)
* add synchronize function * fix linux * fix linux * fix and fix docs * fix test * try synchronize in stream destroy * synchronize works for both cpu and gpu
This commit is contained in:
		@@ -32,10 +32,9 @@ work.
 | 
			
		||||
 | 
			
		||||
    trace_file = "mlx_trace.gputrace"
 | 
			
		||||
 | 
			
		||||
    if not mx.metal.start_capture(trace_file):
 | 
			
		||||
      print("Make sure to run with MTL_CAPTURE_ENABLED=1 and "
 | 
			
		||||
            f"that the path {trace_file} does not already exist.")
 | 
			
		||||
      exit(1)
 | 
			
		||||
    # Make sure to run with MTL_CAPTURE_ENABLED=1 and
 | 
			
		||||
    # that the path trace_file does not already exist.
 | 
			
		||||
    mx.metal.start_capture(trace_file)
 | 
			
		||||
 | 
			
		||||
    for _ in range(10):
 | 
			
		||||
      mx.eval(mx.add(a, b))
 | 
			
		||||
 
 | 
			
		||||
@@ -16,3 +16,4 @@ Devices and Streams
 | 
			
		||||
   new_stream
 | 
			
		||||
   set_default_stream
 | 
			
		||||
   stream
 | 
			
		||||
   synchronize
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ int main() {
 | 
			
		||||
  // To use Metal debugging and profiling:
 | 
			
		||||
  // 1. Build with the MLX_METAL_DEBUG CMake option (i.e. -DMLX_METAL_DEBUG=ON).
 | 
			
		||||
  // 2. Run with MTL_CAPTURE_ENABLED=1.
 | 
			
		||||
  assert(metal::start_capture("mlx_trace.gputrace"));
 | 
			
		||||
  metal::start_capture("mlx_trace.gputrace");
 | 
			
		||||
 | 
			
		||||
  // Start at index two because the default GPU and CPU streams have indices
 | 
			
		||||
  // zero and one, respectively. This naming matches the label assigned to each
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,5 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <future>
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/metal/device.h"
 | 
			
		||||
@@ -115,13 +114,31 @@ std::function<void()> make_task(array arr, bool signal) {
 | 
			
		||||
  return task;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool start_capture(std::string path, id object) {
 | 
			
		||||
// Build the task that `synchronize` enqueues on a GPU stream.
//
// When the returned callable runs it:
//   1. fetches the command buffer currently associated with stream `s`
//      (creating a fresh one if there is none, otherwise ending the
//      encoding on the existing one),
//   2. commits that command buffer and waits until it has completed,
//   3. passes the buffer to `check_error`, and
//   4. fulfils the promise `p`, which wakes whoever waits on its future.
//
// The promise is shared so that it stays alive inside the callable after
// the caller's own handle has been moved from.
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  auto task = [s, promise = std::move(p)]() {
    auto& dev = metal::device(s.device);
    auto command_buffer = dev.get_command_buffer(s.index);
    if (command_buffer != nullptr) {
      // A buffer is already open for this stream: close its encoder
      // before committing.
      dev.end_encoding(s.index);
    } else {
      // Nothing pending: create a buffer so there is something to commit
      // and wait on.
      command_buffer = dev.new_command_buffer(s.index);
    }
    dev.commit_command_buffer(s.index);
    command_buffer->waitUntilCompleted();
    check_error(command_buffer);
    // Signal only after the wait and the error check have both passed.
    promise->set_value();
  };
  return task;
}
 | 
			
		||||
 | 
			
		||||
void start_capture(std::string path, id object) {
 | 
			
		||||
  auto pool = new_scoped_memory_pool();
 | 
			
		||||
 | 
			
		||||
  auto descriptor = MTL::CaptureDescriptor::alloc()->init();
 | 
			
		||||
  descriptor->setCaptureObject(object);
 | 
			
		||||
 | 
			
		||||
  if (path.length() > 0) {
 | 
			
		||||
  if (!path.empty()) {
 | 
			
		||||
    auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
 | 
			
		||||
    auto url = NS::URL::fileURLWithPath(string);
 | 
			
		||||
    descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
 | 
			
		||||
@@ -129,15 +146,24 @@ bool start_capture(std::string path, id object) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto manager = MTL::CaptureManager::sharedCaptureManager();
 | 
			
		||||
  return manager->startCapture(descriptor, nullptr);
 | 
			
		||||
  NS::Error* error;
 | 
			
		||||
  bool started = manager->startCapture(descriptor, &error);
 | 
			
		||||
  descriptor->release();
 | 
			
		||||
  if (!started) {
 | 
			
		||||
    std::ostringstream msg;
 | 
			
		||||
    msg << "[metal::start_capture] Failed to start: "
 | 
			
		||||
        << error->localizedDescription()->utf8String();
 | 
			
		||||
    throw std::runtime_error(msg.str());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool start_capture(std::string path) {
 | 
			
		||||
void start_capture(std::string path) {
 | 
			
		||||
  auto& device = metal::device(mlx::core::Device::gpu);
 | 
			
		||||
  return start_capture(path, device.mtl_device());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void stop_capture() {
 | 
			
		||||
  auto pool = new_scoped_memory_pool();
 | 
			
		||||
  auto manager = MTL::CaptureManager::sharedCaptureManager();
 | 
			
		||||
  manager->stopCapture();
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -55,7 +55,7 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
 | 
			
		||||
size_t set_cache_limit(size_t limit);
 | 
			
		||||
 | 
			
		||||
/** Capture a GPU trace, saving it to an absolute file `path` */
 | 
			
		||||
bool start_capture(std::string path = "");
 | 
			
		||||
void start_capture(std::string path = "");
 | 
			
		||||
/** Stop the GPU trace capture (counterpart of `start_capture`). */
void stop_capture();
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <future>
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "mlx/array.h"
 | 
			
		||||
@@ -15,4 +16,8 @@ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
 | 
			
		||||
 | 
			
		||||
std::function<void()> make_task(array arr, bool signal);
 | 
			
		||||
 | 
			
		||||
std::function<void()> make_synchronize_task(
 | 
			
		||||
    Stream s,
 | 
			
		||||
    std::shared_ptr<std::promise<void>> p);
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,14 @@ std::function<void()> make_task(array arr, bool signal) {
 | 
			
		||||
      "[metal::make_task] Cannot make GPU task without metal backend");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Stub used when the Metal backend is not built: there is no GPU stream to
// synchronize with, so the call always throws std::runtime_error. The
// arguments are ignored.
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  throw std::runtime_error(
      "[metal::make_synchronize_task] Cannot synchronize GPU without metal backend");
}
 | 
			
		||||
 | 
			
		||||
// No-ops when Metal is not available.
 | 
			
		||||
size_t get_active_memory() {
 | 
			
		||||
  return 0;
 | 
			
		||||
@@ -38,9 +46,7 @@ size_t set_memory_limit(size_t, bool) {
 | 
			
		||||
size_t set_cache_limit(size_t) {
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
bool start_capture(std::string path) {
 | 
			
		||||
  return false;
 | 
			
		||||
}
 | 
			
		||||
// No-op: nothing is captured when Metal is not available; `path` is ignored.
void start_capture(std::string path) {}
 | 
			
		||||
// No-op: there is no capture to stop when Metal is not available.
void stop_capture() {}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -33,6 +33,21 @@ Stream new_stream() {
 | 
			
		||||
  return scheduler::scheduler().new_stream(default_device());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Block the calling thread until a marker task enqueued on stream `s` has
// run.
//
// A shared promise/future pair carries the signal: the marker task fulfils
// the promise and this function waits on the matching future.
// NOTE(review): this only acts as a barrier for earlier work if the
// scheduler runs a stream's tasks in enqueue order; confirm against
// scheduler::enqueue.
void synchronize(Stream s) {
  auto done = std::make_shared<std::promise<void>>();
  auto finished = done->get_future();
  const bool on_cpu = (s.device == mlx::core::Device::cpu);
  if (on_cpu) {
    // CPU stream: the marker task only has to fulfil the promise.
    scheduler::enqueue(s, [done = std::move(done)]() { done->set_value(); });
  } else {
    // Any other device: the Metal backend builds the marker task, which
    // fulfils the promise once its own work is finished.
    scheduler::enqueue(s, metal::make_synchronize_task(s, std::move(done)));
  }
  finished.wait();
}
 | 
			
		||||
 | 
			
		||||
// Synchronize with the default stream of the default device.
void synchronize() {
  auto stream = default_stream(default_device());
  synchronize(stream);
}
 | 
			
		||||
 | 
			
		||||
namespace scheduler {
 | 
			
		||||
 | 
			
		||||
/** A singleton scheduler to manage devices, streams, and task execution. */
 | 
			
		||||
 
 | 
			
		||||
@@ -27,6 +27,7 @@ struct StreamThread {
 | 
			
		||||
      : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
 | 
			
		||||
 | 
			
		||||
  ~StreamThread() {
 | 
			
		||||
    synchronize(stream);
 | 
			
		||||
    {
 | 
			
		||||
      std::unique_lock<std::mutex> lk(mtx);
 | 
			
		||||
      stop = true;
 | 
			
		||||
 
 | 
			
		||||
@@ -29,4 +29,10 @@ inline bool operator!=(const Stream& lhs, const Stream& rhs) {
 | 
			
		||||
  return !(lhs == rhs);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Synchronize with the default stream. */
 | 
			
		||||
void synchronize();
 | 
			
		||||
 | 
			
		||||
/* Synchronize with the provided stream. */
 | 
			
		||||
void synchronize(Stream);
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/metal/metal.h"
 | 
			
		||||
#include <nanobind/nanobind.h>
 | 
			
		||||
#include <nanobind/stl/optional.h>
 | 
			
		||||
#include <nanobind/stl/string.h>
 | 
			
		||||
 | 
			
		||||
namespace nb = nanobind;
 | 
			
		||||
@@ -99,9 +100,6 @@ void init_metal(nb::module_& m) {
 | 
			
		||||
      Args:
 | 
			
		||||
        path (str): The path to save the capture which should have
 | 
			
		||||
          the extension ``.gputrace``.
 | 
			
		||||
 | 
			
		||||
      Returns:
 | 
			
		||||
        bool: Whether the capture was successfully started.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  metal.def(
 | 
			
		||||
      "stop_capture",
 | 
			
		||||
 
 | 
			
		||||
@@ -129,4 +129,17 @@ void init_stream(nb::module_& m) {
 | 
			
		||||
              # Operations here will use mx.cpu by default.
 | 
			
		||||
              pass
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "synchronize",
 | 
			
		||||
      [](const std::optional<Stream>& s) {
 | 
			
		||||
        s ? synchronize(s.value()) : synchronize();
 | 
			
		||||
      },
 | 
			
		||||
      "stream"_a = nb::none(),
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
      Synchronize with the given stream.
 | 
			
		||||
 | 
			
		||||
      Args:
 | 
			
		||||
        (Stream, optional): The stream to synchronize with. If ``None`` then
 | 
			
		||||
           the default stream of the default device is used. Default: ``None``.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -24,14 +24,16 @@ class TestMetal(mlx_tests.MLXTestCase):
 | 
			
		||||
        self.assertTrue(mx.metal.set_memory_limit(old_limit), old_limit)
 | 
			
		||||
 | 
			
		||||
        # Query active and peak memory
 | 
			
		||||
        a = mx.zeros((4096,), stream=mx.cpu)
 | 
			
		||||
        a = mx.zeros((4096,))
 | 
			
		||||
        mx.eval(a)
 | 
			
		||||
        mx.synchronize()
 | 
			
		||||
        active_mem = mx.metal.get_active_memory()
 | 
			
		||||
        self.assertTrue(active_mem >= 4096 * 4)
 | 
			
		||||
 | 
			
		||||
        b = mx.zeros((4096,), stream=mx.cpu)
 | 
			
		||||
        b = mx.zeros((4096,))
 | 
			
		||||
        mx.eval(b)
 | 
			
		||||
        del b
 | 
			
		||||
        mx.synchronize()
 | 
			
		||||
 | 
			
		||||
        new_active_mem = mx.metal.get_active_memory()
 | 
			
		||||
        self.assertEqual(new_active_mem, active_mem)
 | 
			
		||||
 
 | 
			
		||||
@@ -495,16 +495,16 @@ TEST_CASE("test metal memory info") {
 | 
			
		||||
 | 
			
		||||
  // Query active and peak memory
 | 
			
		||||
  {
 | 
			
		||||
    // Do these tests on the CPU since deallocation is synchronized
 | 
			
		||||
    // with the main thread.
 | 
			
		||||
    auto a = zeros({4096}, Device::cpu);
 | 
			
		||||
    auto a = zeros({4096});
 | 
			
		||||
    eval(a);
 | 
			
		||||
    synchronize();
 | 
			
		||||
    auto active_mem = metal::get_active_memory();
 | 
			
		||||
    CHECK(active_mem >= 4096 * 4);
 | 
			
		||||
    {
 | 
			
		||||
      auto b = zeros({4096}, Device::cpu);
 | 
			
		||||
      auto b = zeros({4096});
 | 
			
		||||
      eval(b);
 | 
			
		||||
    }
 | 
			
		||||
    synchronize();
 | 
			
		||||
    auto new_active_mem = metal::get_active_memory();
 | 
			
		||||
    CHECK_EQ(new_active_mem, active_mem);
 | 
			
		||||
    auto peak_mem = metal::get_peak_memory();
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user