diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b54b1d0a..20f2277d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON) option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) +option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) @@ -65,8 +66,14 @@ endif() if (MLX_BUILD_METAL AND NOT METAL_LIB) message(STATUS "Metal not found. Unable to build GPU") set(MLX_BUILD_METAL OFF) + set(MLX_METAL_DEBUG OFF) elseif (MLX_BUILD_METAL) message(STATUS "Building METAL sources") + + if (MLX_METAL_DEBUG) + add_compile_definitions(MLX_METAL_DEBUG) + endif() + # Throw an error if xcrun not found execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" OUTPUT_VARIABLE MACOS_VERSION diff --git a/docs/src/_static/metal_debugger/capture.png b/docs/src/_static/metal_debugger/capture.png new file mode 100644 index 000000000..156e9b1cf Binary files /dev/null and b/docs/src/_static/metal_debugger/capture.png differ diff --git a/docs/src/_static/metal_debugger/schema.png b/docs/src/_static/metal_debugger/schema.png new file mode 100644 index 000000000..f84ff53bf Binary files /dev/null and b/docs/src/_static/metal_debugger/schema.png differ diff --git a/docs/src/dev/metal_debugger.rst b/docs/src/dev/metal_debugger.rst new file mode 100644 index 000000000..b0d7db9d0 --- /dev/null +++ b/docs/src/dev/metal_debugger.rst @@ -0,0 +1,52 @@ +Metal Debugger +============== + +Profiling is a key step for performance optimization. You can build MLX with +the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization +workflow. The ``MLX_METAL_DEBUG`` debug option: + +* Records source during Metal compilation, for later inspection while + debugging. +* Labels Metal objects such as command queues, improving capture readability. + +The ``metal::start_capture`` function initiates a capture of all MLX GPU work. + +.. code-block:: C++ + + int main() { + metal::start_capture("/Users/Jane/Developer/MLX.gputrace"); + + auto a = arange(10.f, 20.f, 1.f, float32); + auto b = arange(30.f, 40.f, 1.f, float32); + auto c = add(a, b); + + eval(c); + + metal::stop_capture(); + } + +You can open and replay the GPU trace in Xcode. The ``Dependencies`` view +has a great overview of all operations. Checkout the `Metal debugger +documentation`_ for more information. + +.. image:: ../_static/metal_debugger/capture.png + :class: dark-light + +Xcode Workflow +-------------- + +You can skip saving to a path by running within Xcode. First, generate an Xcode +project using CMake. + +.. code-block:: + + mkdir build && cd build + cmake .. -DMLX_METAL_DEBUG=ON -G Xcode + open mlx.xcodeproj + +Select the ``metal_capture`` example schema and run. + +.. image:: ../_static/metal_debugger/schema.png + :class: dark-light + +.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger diff --git a/docs/src/index.rst b/docs/src/index.rst index a9ec3899f..33d652f6d 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -82,3 +82,4 @@ are the CPU and GPU. :maxdepth: 1 dev/extensions + dev/metal_debugger diff --git a/docs/src/install.rst b/docs/src/install.rst index 43571f95d..7001d896f 100644 --- a/docs/src/install.rst +++ b/docs/src/install.rst @@ -155,6 +155,8 @@ should point to the path to the built metal library. - ON * - MLX_BUILD_PYTHON_BINDINGS - OFF + * - MLX_METAL_DEBUG + - OFF .. note:: diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 71c6472aa..cabb723fa 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -8,3 +8,4 @@ endfunction(build_example) build_example(tutorial.cpp) build_example(linear_regression.cpp) build_example(logistic_regression.cpp) +build_example(metal_capture.cpp) diff --git a/examples/cpp/metal_capture.cpp b/examples/cpp/metal_capture.cpp new file mode 100644 index 000000000..db5514786 --- /dev/null +++ b/examples/cpp/metal_capture.cpp @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "mlx/mlx.h" + +using namespace mlx::core; + +int main() { + // Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups, + // labels, etc. + assert(metal::start_capture()); + + // Start at index two because the default GPU and CPU streams have indices + // zero and one, respectively. This naming matches the label assigned to each + // stream's command queue. + auto s2 = new_stream(Device::gpu); + auto s3 = new_stream(Device::gpu); + + auto a = arange(1.f, 10.f, 1.f, float32, s2); + auto b = arange(1.f, 10.f, 1.f, float32, s3); + auto x = add(a, a, s2); + auto y = add(b, b, s3); + + // The multiply will happen on the default stream. + std::cout << multiply(x, y) << std::endl; + + metal::stop_capture(); +} diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index a2966362f..21b2930bc 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -12,6 +12,7 @@ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/mps/gemm.h" +#include "mlx/backend/metal/utils.h" namespace fs = std::filesystem; @@ -145,6 +146,7 @@ void Device::new_queue(int index) { // We lock this as a critical section for safety const std::lock_guard lock(mtx_); auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); + debug_set_stream_queue_label(q, index); if (!q) { throw std::runtime_error( "[metal::Device] Failed to make new command queue."); diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 262f396e3..64ee1889c 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -37,11 +37,17 @@ set( ) function(build_kernel_base TARGET SRCFILE DEPS) + set(METAL_FLAGS -Wall -Wextra -fno-fast-math) + if(MLX_METAL_DEBUG) + set(METAL_FLAGS ${METAL_FLAGS} + -gline-tables-only + -frecord-sources) + endif() add_custom_command( - COMMAND xcrun -sdk macosx metal -Wall -Wextra - -fno-fast-math - -c ${SRCFILE} - -I${PROJECT_SOURCE_DIR} + COMMAND xcrun -sdk macosx metal + ${METAL_FLAGS} + -c ${SRCFILE} + -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air DEPENDS ${SRCFILE} ${DEPS} OUTPUT ${TARGET}.air diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 6035d2a7f..07cb4e900 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -5,6 +5,7 @@ #include #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -15,6 +16,9 @@ bool is_available() { } int max_ops_per_buffer() { +#ifdef MLX_METAL_DEBUG + return 1; +#else auto get_val = []() { if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { return atoi(buff_str); @@ -24,6 +28,7 @@ int max_ops_per_buffer() { }; static int max_ops_per_buffer_ = get_val(); return max_ops_per_buffer_; +#endif } #define MAX_OPS_PER_BUFFER max_ops_per_buffer() @@ -74,6 +79,8 @@ std::function make_task( if (arr.is_tracer()) { inputs = arr.inputs(); } + + debug_set_primitive_buffer_label(command_buffer, arr.primitive()); arr.primitive().eval_gpu(arr.inputs(), outputs); } std::vector> buffers; @@ -108,4 +115,31 @@ std::function make_task( return task; } +bool start_capture(std::string path, id object) { + auto pool = new_scoped_memory_pool(); + + auto descriptor = MTL::CaptureDescriptor::alloc()->init(); + descriptor->setCaptureObject(object); + + if (path.length() > 0) { + auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding); + auto url = NS::URL::fileURLWithPath(string); + descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument); + descriptor->setOutputURL(url); + } + + auto manager = MTL::CaptureManager::sharedCaptureManager(); + return manager->startCapture(descriptor, nullptr); +} + +bool start_capture(std::string path) { + auto& device = metal::device(mlx::core::Device::gpu); + return start_capture(path, device.mtl_device()); +} + +void stop_capture() { + auto manager = MTL::CaptureManager::sharedCaptureManager(); + manager->stopCapture(); +} + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index 360481f81..ffbfe0ed0 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -66,4 +66,8 @@ std::function make_task( std::vector> deps, std::shared_ptr> p); +/** Capture a GPU trace, saving it to an absolute file `path` */ +bool start_capture(std::string path = ""); +void stop_capture(); + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 10aea8ab0..a73571914 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -4,6 +4,7 @@ #include "mlx/array.h" #include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" namespace mlx::core { @@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) { return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; } +inline NS::String* make_string(std::ostringstream& os) { + std::string string = os.str(); + return NS::String::string(string.c_str(), NS::UTF8StringEncoding); +} + +inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + label << "Stream " << index; + queue->setLabel(make_string(label)); +#endif +} + +inline void debug_set_primitive_buffer_label( + MTL::CommandBuffer* command_buffer, + Primitive& primitive) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + primitive.print(label); + command_buffer->setLabel(make_string(label)); +#endif +} + } // namespace } // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index 240e00c41..01def113a 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) { size_t set_cache_limit(size_t) { return 0; } +bool start_capture(std::string path) { + return false; +} +void stop_capture() {} } // namespace mlx::core::metal diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 74b1e1b04..e66310ee8 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -28,7 +28,8 @@ class Synchronizer : public Primitive { void eval_cpu(const std::vector&, std::vector&) override{}; void eval_gpu(const std::vector&, std::vector&) override{}; - void print(std::ostream&) override {} + + DEFINE_PRINT(Synchronize); }; // Initialize the static tracing counter from transforms_impl.h .