mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions * Add brief Metal debugger documentation * doc nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
a7b404ff53
commit
45f636e759
@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
@ -65,8 +66,14 @@ endif()
|
||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
set(MLX_METAL_DEBUG OFF)
|
||||
elseif (MLX_BUILD_METAL)
|
||||
message(STATUS "Building METAL sources")
|
||||
|
||||
if (MLX_METAL_DEBUG)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
# Throw an error if xcrun not found
|
||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
OUTPUT_VARIABLE MACOS_VERSION
|
||||
|
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 746 KiB |
52
docs/src/dev/metal_debugger.rst
Normal file
52
docs/src/dev/metal_debugger.rst
Normal file
@ -0,0 +1,52 @@
|
||||
Metal Debugger
|
||||
==============
|
||||
|
||||
Profiling is a key step for performance optimization. You can build MLX with
|
||||
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization
|
||||
workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||
|
||||
* Records source during Metal compilation, for later inspection while
|
||||
debugging.
|
||||
* Labels Metal objects such as command queues, improving capture readability.
|
||||
|
||||
The ``metal::start_capture`` function initiates a capture of all MLX GPU work.
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
int main() {
|
||||
metal::start_capture("/Users/Jane/Developer/MLX.gputrace");
|
||||
|
||||
auto a = arange(10.f, 20.f, 1.f, float32);
|
||||
auto b = arange(30.f, 40.f, 1.f, float32);
|
||||
auto c = add(a, b);
|
||||
|
||||
eval(c);
|
||||
|
||||
metal::stop_capture();
|
||||
}
|
||||
|
||||
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||
has a great overview of all operations. Checkout the `Metal debugger
|
||||
documentation`_ for more information.
|
||||
|
||||
.. image:: ../_static/metal_debugger/capture.png
|
||||
:class: dark-light
|
||||
|
||||
Xcode Workflow
|
||||
--------------
|
||||
|
||||
You can skip saving to a path by running within Xcode. First, generate an Xcode
|
||||
project using CMake.
|
||||
|
||||
.. code-block::
|
||||
|
||||
mkdir build && cd build
|
||||
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
|
||||
open mlx.xcodeproj
|
||||
|
||||
Select the ``metal_capture`` example schema and run.
|
||||
|
||||
.. image:: ../_static/metal_debugger/schema.png
|
||||
:class: dark-light
|
||||
|
||||
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
|
@ -82,3 +82,4 @@ are the CPU and GPU.
|
||||
:maxdepth: 1
|
||||
|
||||
dev/extensions
|
||||
dev/metal_debugger
|
||||
|
@ -155,6 +155,8 @@ should point to the path to the built metal library.
|
||||
- ON
|
||||
* - MLX_BUILD_PYTHON_BINDINGS
|
||||
- OFF
|
||||
* - MLX_METAL_DEBUG
|
||||
- OFF
|
||||
|
||||
|
||||
.. note::
|
||||
|
@ -8,3 +8,4 @@ endfunction(build_example)
|
||||
build_example(tutorial.cpp)
|
||||
build_example(linear_regression.cpp)
|
||||
build_example(logistic_regression.cpp)
|
||||
build_example(metal_capture.cpp)
|
||||
|
30
examples/cpp/metal_capture.cpp
Normal file
30
examples/cpp/metal_capture.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
int main() {
|
||||
// Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups,
|
||||
// labels, etc.
|
||||
assert(metal::start_capture());
|
||||
|
||||
// Start at index two because the default GPU and CPU streams have indices
|
||||
// zero and one, respectively. This naming matches the label assigned to each
|
||||
// stream's command queue.
|
||||
auto s2 = new_stream(Device::gpu);
|
||||
auto s3 = new_stream(Device::gpu);
|
||||
|
||||
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||
auto x = add(a, a, s2);
|
||||
auto y = add(b, b, s3);
|
||||
|
||||
// The multiply will happen on the default stream.
|
||||
std::cout << multiply(x, y) << std::endl;
|
||||
|
||||
metal::stop_capture();
|
||||
}
|
@ -12,6 +12,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@ -145,6 +146,7 @@ void Device::new_queue(int index) {
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
|
@ -37,11 +37,17 @@ set(
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||
-fno-fast-math
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
${METAL_FLAGS}
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
OUTPUT ${TARGET}.air
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
@ -15,6 +16,9 @@ bool is_available() {
|
||||
}
|
||||
|
||||
int max_ops_per_buffer() {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
return 1;
|
||||
#else
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
||||
return atoi(buff_str);
|
||||
@ -24,6 +28,7 @@ int max_ops_per_buffer() {
|
||||
};
|
||||
static int max_ops_per_buffer_ = get_val();
|
||||
return max_ops_per_buffer_;
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
@ -74,6 +79,8 @@ std::function<void()> make_task(
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
@ -108,4 +115,31 @@ std::function<void()> make_task(
|
||||
return task;
|
||||
}
|
||||
|
||||
bool start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||
descriptor->setCaptureObject(object);
|
||||
|
||||
if (path.length() > 0) {
|
||||
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||
auto url = NS::URL::fileURLWithPath(string);
|
||||
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||
descriptor->setOutputURL(url);
|
||||
}
|
||||
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
return manager->startCapture(descriptor, nullptr);
|
||||
}
|
||||
|
||||
bool start_capture(std::string path) {
|
||||
auto& device = metal::device(mlx::core::Device::gpu);
|
||||
return start_capture(path, device.mtl_device());
|
||||
}
|
||||
|
||||
void stop_capture() {
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -66,4 +66,8 @@ std::function<void()> make_task(
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
bool start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||
}
|
||||
|
||||
inline NS::String* make_string(std::ostringstream& os) {
|
||||
std::string string = os.str();
|
||||
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
|
||||
}
|
||||
|
||||
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
label << "Stream " << index;
|
||||
queue->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void debug_set_primitive_buffer_label(
|
||||
MTL::CommandBuffer* command_buffer,
|
||||
Primitive& primitive) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
primitive.print(label);
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
bool start_capture(std::string path) {
|
||||
return false;
|
||||
}
|
||||
void stop_capture() {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void print(std::ostream&) override {}
|
||||
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
|
Loading…
Reference in New Issue
Block a user