mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions * Add brief Metal debugger documentation * doc nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
a7b404ff53
commit
45f636e759
@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
|
|||||||
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
@ -65,8 +66,14 @@ endif()
|
|||||||
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
if (MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
|
set(MLX_METAL_DEBUG OFF)
|
||||||
elseif (MLX_BUILD_METAL)
|
elseif (MLX_BUILD_METAL)
|
||||||
message(STATUS "Building METAL sources")
|
message(STATUS "Building METAL sources")
|
||||||
|
|
||||||
|
if (MLX_METAL_DEBUG)
|
||||||
|
add_compile_definitions(MLX_METAL_DEBUG)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Throw an error if xcrun not found
|
# Throw an error if xcrun not found
|
||||||
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||||
OUTPUT_VARIABLE MACOS_VERSION
|
OUTPUT_VARIABLE MACOS_VERSION
|
||||||
|
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
BIN
docs/src/_static/metal_debugger/capture.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
BIN
docs/src/_static/metal_debugger/schema.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 746 KiB |
52
docs/src/dev/metal_debugger.rst
Normal file
52
docs/src/dev/metal_debugger.rst
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
Metal Debugger
|
||||||
|
==============
|
||||||
|
|
||||||
|
Profiling is a key step for performance optimization. You can build MLX with
|
||||||
|
the ``MLX_METAL_DEBUG`` option to improve the Metal debugging and optimization
|
||||||
|
workflow. The ``MLX_METAL_DEBUG`` debug option:
|
||||||
|
|
||||||
|
* Records source during Metal compilation, for later inspection while
|
||||||
|
debugging.
|
||||||
|
* Labels Metal objects such as command queues, improving capture readability.
|
||||||
|
|
||||||
|
The ``metal::start_capture`` function initiates a capture of all MLX GPU work.
|
||||||
|
|
||||||
|
.. code-block:: C++
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
metal::start_capture("/Users/Jane/Developer/MLX.gputrace");
|
||||||
|
|
||||||
|
auto a = arange(10.f, 20.f, 1.f, float32);
|
||||||
|
auto b = arange(30.f, 40.f, 1.f, float32);
|
||||||
|
auto c = add(a, b);
|
||||||
|
|
||||||
|
eval(c);
|
||||||
|
|
||||||
|
metal::stop_capture();
|
||||||
|
}
|
||||||
|
|
||||||
|
You can open and replay the GPU trace in Xcode. The ``Dependencies`` view
|
||||||
|
has a great overview of all operations. Check out the `Metal debugger
|
||||||
|
documentation`_ for more information.
|
||||||
|
|
||||||
|
.. image:: ../_static/metal_debugger/capture.png
|
||||||
|
:class: dark-light
|
||||||
|
|
||||||
|
Xcode Workflow
|
||||||
|
--------------
|
||||||
|
|
||||||
|
You can skip saving to a path by running within Xcode. First, generate an Xcode
|
||||||
|
project using CMake.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
mkdir build && cd build
|
||||||
|
cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
|
||||||
|
open mlx.xcodeproj
|
||||||
|
|
||||||
|
Select the ``metal_capture`` example scheme and run.
|
||||||
|
|
||||||
|
.. image:: ../_static/metal_debugger/schema.png
|
||||||
|
:class: dark-light
|
||||||
|
|
||||||
|
.. _`Metal debugger documentation`: https://developer.apple.com/documentation/xcode/metal-debugger
|
@ -82,3 +82,4 @@ are the CPU and GPU.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
dev/extensions
|
dev/extensions
|
||||||
|
dev/metal_debugger
|
||||||
|
@ -155,6 +155,8 @@ should point to the path to the built metal library.
|
|||||||
- ON
|
- ON
|
||||||
* - MLX_BUILD_PYTHON_BINDINGS
|
* - MLX_BUILD_PYTHON_BINDINGS
|
||||||
- OFF
|
- OFF
|
||||||
|
* - MLX_METAL_DEBUG
|
||||||
|
- OFF
|
||||||
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -8,3 +8,4 @@ endfunction(build_example)
|
|||||||
build_example(tutorial.cpp)
|
build_example(tutorial.cpp)
|
||||||
build_example(linear_regression.cpp)
|
build_example(linear_regression.cpp)
|
||||||
build_example(logistic_regression.cpp)
|
build_example(logistic_regression.cpp)
|
||||||
|
build_example(metal_capture.cpp)
|
||||||
|
30
examples/cpp/metal_capture.cpp
Normal file
30
examples/cpp/metal_capture.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// Enable the MLX_METAL_DEBUG CMake option to enhance the capture with groups,
|
||||||
|
// labels, etc.
|
||||||
|
assert(metal::start_capture());
|
||||||
|
|
||||||
|
// Start at index two because the default GPU and CPU streams have indices
|
||||||
|
// zero and one, respectively. This naming matches the label assigned to each
|
||||||
|
// stream's command queue.
|
||||||
|
auto s2 = new_stream(Device::gpu);
|
||||||
|
auto s3 = new_stream(Device::gpu);
|
||||||
|
|
||||||
|
auto a = arange(1.f, 10.f, 1.f, float32, s2);
|
||||||
|
auto b = arange(1.f, 10.f, 1.f, float32, s3);
|
||||||
|
auto x = add(a, a, s2);
|
||||||
|
auto y = add(b, b, s3);
|
||||||
|
|
||||||
|
// The multiply will happen on the default stream.
|
||||||
|
std::cout << multiply(x, y) << std::endl;
|
||||||
|
|
||||||
|
metal::stop_capture();
|
||||||
|
}
|
@ -12,6 +12,7 @@
|
|||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/backend/metal/mps/gemm.h"
|
#include "mlx/backend/metal/mps/gemm.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
|
||||||
namespace fs = std::filesystem;
|
namespace fs = std::filesystem;
|
||||||
|
|
||||||
@ -145,6 +146,7 @@ void Device::new_queue(int index) {
|
|||||||
// We lock this as a critical section for safety
|
// We lock this as a critical section for safety
|
||||||
const std::lock_guard<std::mutex> lock(mtx_);
|
const std::lock_guard<std::mutex> lock(mtx_);
|
||||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||||
|
debug_set_stream_queue_label(q, index);
|
||||||
if (!q) {
|
if (!q) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::Device] Failed to make new command queue.");
|
"[metal::Device] Failed to make new command queue.");
|
||||||
|
@ -37,9 +37,15 @@ set(
|
|||||||
)
|
)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||||
|
if(MLX_METAL_DEBUG)
|
||||||
|
set(METAL_FLAGS ${METAL_FLAGS}
|
||||||
|
-gline-tables-only
|
||||||
|
-frecord-sources)
|
||||||
|
endif()
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
COMMAND xcrun -sdk macosx metal
|
||||||
-fno-fast-math
|
${METAL_FLAGS}
|
||||||
-c ${SRCFILE}
|
-c ${SRCFILE}
|
||||||
-I${PROJECT_SOURCE_DIR}
|
-I${PROJECT_SOURCE_DIR}
|
||||||
-o ${TARGET}.air
|
-o ${TARGET}.air
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
@ -15,6 +16,9 @@ bool is_available() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int max_ops_per_buffer() {
|
int max_ops_per_buffer() {
|
||||||
|
#ifdef MLX_METAL_DEBUG
|
||||||
|
return 1;
|
||||||
|
#else
|
||||||
auto get_val = []() {
|
auto get_val = []() {
|
||||||
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
||||||
return atoi(buff_str);
|
return atoi(buff_str);
|
||||||
@ -24,6 +28,7 @@ int max_ops_per_buffer() {
|
|||||||
};
|
};
|
||||||
static int max_ops_per_buffer_ = get_val();
|
static int max_ops_per_buffer_ = get_val();
|
||||||
return max_ops_per_buffer_;
|
return max_ops_per_buffer_;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||||
@ -74,6 +79,8 @@ std::function<void()> make_task(
|
|||||||
if (arr.is_tracer()) {
|
if (arr.is_tracer()) {
|
||||||
inputs = arr.inputs();
|
inputs = arr.inputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||||
}
|
}
|
||||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||||
@ -108,4 +115,31 @@ std::function<void()> make_task(
|
|||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool start_capture(std::string path, id object) {
|
||||||
|
auto pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
|
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||||
|
descriptor->setCaptureObject(object);
|
||||||
|
|
||||||
|
if (path.length() > 0) {
|
||||||
|
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||||
|
auto url = NS::URL::fileURLWithPath(string);
|
||||||
|
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||||
|
descriptor->setOutputURL(url);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||||
|
return manager->startCapture(descriptor, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool start_capture(std::string path) {
|
||||||
|
auto& device = metal::device(mlx::core::Device::gpu);
|
||||||
|
return start_capture(path, device.mtl_device());
|
||||||
|
}
|
||||||
|
|
||||||
|
void stop_capture() {
|
||||||
|
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||||
|
manager->stopCapture();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -66,4 +66,8 @@ std::function<void()> make_task(
|
|||||||
std::vector<std::shared_future<void>> deps,
|
std::vector<std::shared_future<void>> deps,
|
||||||
std::shared_ptr<std::promise<void>> p);
|
std::shared_ptr<std::promise<void>> p);
|
||||||
|
|
||||||
|
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||||
|
bool start_capture(std::string path = "");
|
||||||
|
void stop_capture();
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
|||||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build an NS::String from the text accumulated in `os`, interpreting the
// bytes as UTF-8.
inline NS::String* make_string(std::ostringstream& os) {
  // Keep the std::string alive in a named local while its c_str() is in use.
  const std::string contents = os.str();
  const char* utf8 = contents.c_str();
  return NS::String::string(utf8, NS::UTF8StringEncoding);
}
|
||||||
|
|
||||||
|
// Label `queue` as "Stream <index>" when MLX_METAL_DEBUG is defined; otherwise
// this function does nothing.
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
  // The queue may be null if its creation failed; skip labeling so the caller
  // can report that failure instead of dereferencing a null pointer here.
  if (queue == nullptr) {
    return;
  }
  std::ostringstream label;
  label << "Stream " << index;
  queue->setLabel(make_string(label));
#endif
}
|
||||||
|
|
||||||
|
// Label `command_buffer` with whatever `primitive.print` writes to a stream,
// when MLX_METAL_DEBUG is defined; otherwise this function does nothing.
inline void debug_set_primitive_buffer_label(
    MTL::CommandBuffer* command_buffer,
    Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
  std::ostringstream printed;
  primitive.print(printed);
  NS::String* buffer_label = make_string(printed);
  command_buffer->setLabel(buffer_label);
#endif
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
|
|||||||
size_t set_cache_limit(size_t) {
|
size_t set_cache_limit(size_t) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
// Stub implementation: ignores `path` and always returns false.
bool start_capture(std::string /* path */) {
  return false;
}
|
||||||
|
// Stub implementation: does nothing when called.
void stop_capture() {}
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
|
|||||||
|
|
||||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||||
void print(std::ostream&) override {}
|
|
||||||
|
DEFINE_PRINT(Synchronize);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize the static tracing counter from transforms_impl.h .
|
// Initialize the static tracing counter from transforms_impl.h .
|
||||||
|
Loading…
Reference in New Issue
Block a user