Add Metal debug option and capture functions (#707)

* Add Metal debug option and capture functions

* Add brief Metal debugger documentation

* doc nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau
2024-03-28 09:40:31 -07:00
committed by GitHub
parent a7b404ff53
commit 45f636e759
15 changed files with 173 additions and 5 deletions

View File

@@ -12,6 +12,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -145,6 +146,7 @@ void Device::new_queue(int index) {
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");

View File

@@ -37,11 +37,17 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
-frecord-sources)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
COMMAND xcrun -sdk macosx metal
${METAL_FLAGS}
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air

View File

@@ -5,6 +5,7 @@
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@@ -15,6 +16,9 @@ bool is_available() {
}
int max_ops_per_buffer() {
#ifdef MLX_METAL_DEBUG
return 1;
#else
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
@@ -24,6 +28,7 @@ int max_ops_per_buffer() {
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
#endif
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
@@ -74,6 +79,8 @@ std::function<void()> make_task(
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
@@ -108,4 +115,31 @@ std::function<void()> make_task(
return task;
}
bool start_capture(std::string path, id object) {
auto pool = new_scoped_memory_pool();
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
descriptor->setCaptureObject(object);
if (path.length() > 0) {
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
auto url = NS::URL::fileURLWithPath(string);
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
descriptor->setOutputURL(url);
}
auto manager = MTL::CaptureManager::sharedCaptureManager();
return manager->startCapture(descriptor, nullptr);
}
bool start_capture(std::string path) {
auto& device = metal::device(mlx::core::Device::gpu);
return start_capture(path, device.mtl_device());
}
void stop_capture() {
auto manager = MTL::CaptureManager::sharedCaptureManager();
manager->stopCapture();
}
} // namespace mlx::core::metal

View File

@@ -66,4 +66,8 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p);
/** Capture a GPU trace, saving it to an absolute file `path` */
bool start_capture(std::string path = "");
void stop_capture();
} // namespace mlx::core::metal

View File

@@ -4,6 +4,7 @@
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
inline NS::String* make_string(std::ostringstream& os) {
std::string string = os.str();
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
}
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
std::ostringstream label;
label << "Stream " << index;
queue->setLabel(make_string(label));
#endif
}
inline void debug_set_primitive_buffer_label(
MTL::CommandBuffer* command_buffer,
Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
std::ostringstream label;
primitive.print(label);
command_buffer->setLabel(make_string(label));
#endif
}
} // namespace
} // namespace mlx::core

View File

@@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
size_t set_cache_limit(size_t) {
return 0;
}
bool start_capture(std::string path) {
return false;
}
void stop_capture() {}
} // namespace mlx::core::metal

View File

@@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
void print(std::ostream&) override {}
DEFINE_PRINT(Synchronize);
};
// Initialize the static tracing counter from transforms_impl.h .