Add Metal debug option and capture functions (#707)

* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau 2024-03-28 09:40:31 -07:00 committed by GitHub
parent a7b404ff53
commit 45f636e759
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 173 additions and 5 deletions

View File

@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF) option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
@ -65,8 +66,14 @@ endif()
if (MLX_BUILD_METAL AND NOT METAL_LIB) if (MLX_BUILD_METAL AND NOT METAL_LIB)
message(STATUS "Metal not found. Unable to build GPU") message(STATUS "Metal not found. Unable to build GPU")
set(MLX_BUILD_METAL OFF) set(MLX_BUILD_METAL OFF)
set(MLX_METAL_DEBUG OFF)
elseif (MLX_BUILD_METAL) elseif (MLX_BUILD_METAL)
message(STATUS "Building METAL sources") message(STATUS "Building METAL sources")
if (MLX_METAL_DEBUG)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION OUTPUT_VARIABLE MACOS_VERSION

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 746 KiB

View File

@ -0,0 +1,52 @@
Metal Debugger
==============
Profiling is a key step for performance optimization. You can build MLX with
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization
workflow. The ``MLX_METAL_DEBUG`` debug option:
* Records source during Metal compilation, for later inspection while
debugging.
* Labels Metal objects such as command queues, improving capture readability.
The ``metal::start_capture`` function initiates a capture of all MLX GPU work.
.. code-block:: C++
int main() {
metal::start_capture("/Users/Jane/Developer/MLX.gputrace");
auto a = arange(10.f, 20.f, 1.f, float32);
auto b = arange(30.f, 40.f, 1.f, float32);
auto c = add(a, b);
eval(c);
metal::stop_capture();
}
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
has a great overview of all operations. Check out the `Metal debugger
documentation`_ for more information.
.. image:: ../_static/metal_debugger/capture.png
:class: dark-light
Xcode Workflow
--------------
You can skip saving to a path by running within Xcode. First, generate an Xcode
project using CMake.
.. code-block::
mkdir build && cd build
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
open mlx.xcodeproj
Select the ``metal_capture`` example scheme and run.
.. image:: ../_static/metal_debugger/schema.png
:class: dark-light
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger

View File

@ -82,3 +82,4 @@ are the CPU and GPU.
:maxdepth: 1 :maxdepth: 1
dev/extensions dev/extensions
dev/metal_debugger

View File

@ -155,6 +155,8 @@ should point to the path to the built metal library.
- ON - ON
* - MLX_BUILD_PYTHON_BINDINGS * - MLX_BUILD_PYTHON_BINDINGS
- OFF - OFF
* - MLX_METAL_DEBUG
- OFF
.. note:: .. note::

View File

@ -8,3 +8,4 @@ endfunction(build_example)
build_example(tutorial.cpp) build_example(tutorial.cpp)
build_example(linear_regression.cpp) build_example(linear_regression.cpp)
build_example(logistic_regression.cpp) build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)

View File

@ -0,0 +1,30 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
  // Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups,
  // labels, etc.
  //
  // The call is intentionally made outside of assert(): assert() expands to
  // nothing when NDEBUG is defined, which would silently skip starting the
  // capture in release builds.
  const bool capture_started = metal::start_capture();
  if (!capture_started) {
    std::cerr << "Failed to start the Metal capture." << std::endl;
    return 1;
  }

  // Start at index two because the default GPU and CPU streams have indices
  // zero and one, respectively. This naming matches the label assigned to each
  // stream's command queue.
  auto s2 = new_stream(Device::gpu);
  auto s3 = new_stream(Device::gpu);

  auto a = arange(1.f, 10.f, 1.f, float32, s2);
  auto b = arange(1.f, 10.f, 1.f, float32, s3);
  auto x = add(a, a, s2);
  auto y = add(b, b, s3);

  // The multiply will happen on the default stream.
  std::cout << multiply(x, y) << std::endl;

  metal::stop_capture();
  return 0;
}

View File

@ -12,6 +12,7 @@
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem; namespace fs = std::filesystem;
@ -145,6 +146,7 @@ void Device::new_queue(int index) {
// We lock this as a critical section for safety // We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_); const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE); auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) { if (!q) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Failed to make new command queue."); "[metal::Device] Failed to make new command queue.");

View File

@ -37,9 +37,15 @@ set(
) )
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
-frecord-sources)
endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra COMMAND xcrun -sdk macosx metal
-fno-fast-math ${METAL_FLAGS}
-c ${SRCFILE} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air -o ${TARGET}.air

View File

@ -5,6 +5,7 @@
#include <memory> #include <memory>
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
@ -15,6 +16,9 @@ bool is_available() {
} }
int max_ops_per_buffer() { int max_ops_per_buffer() {
#ifdef MLX_METAL_DEBUG
return 1;
#else
auto get_val = []() { auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str); return atoi(buff_str);
@ -24,6 +28,7 @@ int max_ops_per_buffer() {
}; };
static int max_ops_per_buffer_ = get_val(); static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_; return max_ops_per_buffer_;
#endif
} }
#define MAX_OPS_PER_BUFFER max_ops_per_buffer() #define MAX_OPS_PER_BUFFER max_ops_per_buffer()
@ -74,6 +79,8 @@ std::function<void()> make_task(
if (arr.is_tracer()) { if (arr.is_tracer()) {
inputs = arr.inputs(); inputs = arr.inputs();
} }
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs); arr.primitive().eval_gpu(arr.inputs(), outputs);
} }
std::vector<std::shared_ptr<array::Data>> buffers; std::vector<std::shared_ptr<array::Data>> buffers;
@ -108,4 +115,31 @@ std::function<void()> make_task(
return task; return task;
} }
/**
 * Start a GPU capture scoped to `object`.
 *
 * When `path` is non-empty, the capture destination is set to a GPU trace
 * document written to `path`; otherwise the descriptor's destination is left
 * untouched. Returns the result reported by the capture manager's
 * `startCapture` call.
 */
bool start_capture(std::string path, id object) {
  auto pool = new_scoped_memory_pool();

  auto descriptor = MTL::CaptureDescriptor::alloc()->init();
  descriptor->setCaptureObject(object);

  if (!path.empty()) {
    auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
    auto url = NS::URL::fileURLWithPath(string);
    descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
    descriptor->setOutputURL(url);
  }

  auto manager = MTL::CaptureManager::sharedCaptureManager();
  const bool started = manager->startCapture(descriptor, nullptr);

  // Objects created with alloc()/init() are owned by the caller and are not
  // autoreleased, so the descriptor must be released explicitly to avoid
  // leaking it on every call.
  descriptor->release();
  return started;
}
// Convenience overload: start a capture scoped to the GPU device's underlying
// Metal device object, forwarding `path` unchanged.
bool start_capture(std::string path) {
  auto gpu_device = metal::device(mlx::core::Device::gpu).mtl_device();
  return start_capture(path, gpu_device);
}
// Ask the shared capture manager to stop capturing.
void stop_capture() {
  MTL::CaptureManager::sharedCaptureManager()->stopCapture();
}
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -66,4 +66,8 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p); std::shared_ptr<std::promise<void>> p);
/**
 * Capture a GPU trace, saving it to an absolute file `path`.
 *
 * When `path` is empty (the default) no output file is configured for the
 * capture. Returns whether the capture was started.
 */
bool start_capture(std::string path = "");
/** Stop the GPU trace capture. */
void stop_capture();
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -4,6 +4,7 @@
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
} }
// Convert the text accumulated in `os` into an NS::String, interpreting the
// bytes as UTF-8.
inline NS::String* make_string(std::ostringstream& os) {
  // The temporary std::string lives until the end of the full expression, so
  // the pointer from c_str() is valid for the duration of the call.
  return NS::String::string(os.str().c_str(), NS::UTF8StringEncoding);
}
// Label `queue` as "Stream <index>" when built with MLX_METAL_DEBUG; does
// nothing otherwise. A null `queue` is ignored so the helper can be called
// before the caller has checked whether queue creation succeeded.
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
  if (queue == nullptr) {
    return;
  }
  std::ostringstream label;
  label << "Stream " << index;
  queue->setLabel(make_string(label));
#endif
}
// Label `command_buffer` with the text produced by `primitive.print()` when
// built with MLX_METAL_DEBUG; does nothing otherwise.
inline void debug_set_primitive_buffer_label(
    MTL::CommandBuffer* command_buffer,
    Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
  std::ostringstream description;
  primitive.print(description);
  NS::String* text = make_string(description);
  command_buffer->setLabel(text);
#endif
}
} // namespace } // namespace
} // namespace mlx::core } // namespace mlx::core

View File

@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
size_t set_cache_limit(size_t) { size_t set_cache_limit(size_t) {
return 0; return 0;
} }
// Stub: never starts a capture and always returns false. The parameter is
// left unnamed, like the other stubs in this file, because it is unused.
bool start_capture(std::string) {
  return false;
}
// Stub: intentionally does nothing.
void stop_capture() {}
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{}; void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{}; void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
void print(std::ostream&) override {}
DEFINE_PRINT(Synchronize);
}; };
// Initialize the static tracing counter from transforms_impl.h . // Initialize the static tracing counter from transforms_impl.h .