mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions * Add brief Metal debugger documentation * doc nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -145,6 +146,7 @@ void Device::new_queue(int index) {
|
||||
// We lock this as a critical section for safety
|
||||
const std::lock_guard<std::mutex> lock(mtx_);
|
||||
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
|
||||
debug_set_stream_queue_label(q, index);
|
||||
if (!q) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
|
@@ -37,11 +37,17 @@ set(
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
add_custom_command(
|
||||
COMMAND xcrun -sdk macosx metal -Wall -Wextra
|
||||
-fno-fast-math
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
COMMAND xcrun -sdk macosx metal
|
||||
${METAL_FLAGS}
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
OUTPUT ${TARGET}.air
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
@@ -15,6 +16,9 @@ bool is_available() {
|
||||
}
|
||||
|
||||
int max_ops_per_buffer() {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
return 1;
|
||||
#else
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
||||
return atoi(buff_str);
|
||||
@@ -24,6 +28,7 @@ int max_ops_per_buffer() {
|
||||
};
|
||||
static int max_ops_per_buffer_ = get_val();
|
||||
return max_ops_per_buffer_;
|
||||
#endif
|
||||
}
|
||||
|
||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
@@ -74,6 +79,8 @@ std::function<void()> make_task(
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
@@ -108,4 +115,31 @@ std::function<void()> make_task(
|
||||
return task;
|
||||
}
|
||||
|
||||
bool start_capture(std::string path, id object) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto descriptor = MTL::CaptureDescriptor::alloc()->init();
|
||||
descriptor->setCaptureObject(object);
|
||||
|
||||
if (path.length() > 0) {
|
||||
auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
|
||||
auto url = NS::URL::fileURLWithPath(string);
|
||||
descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
|
||||
descriptor->setOutputURL(url);
|
||||
}
|
||||
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
return manager->startCapture(descriptor, nullptr);
|
||||
}
|
||||
|
||||
bool start_capture(std::string path) {
|
||||
auto& device = metal::device(mlx::core::Device::gpu);
|
||||
return start_capture(path, device.mtl_device());
|
||||
}
|
||||
|
||||
void stop_capture() {
|
||||
auto manager = MTL::CaptureManager::sharedCaptureManager();
|
||||
manager->stopCapture();
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -66,4 +66,8 @@ std::function<void()> make_task(
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p);
|
||||
|
||||
/** Capture a GPU trace, saving it to an absolute file `path` */
|
||||
bool start_capture(std::string path = "");
|
||||
void stop_capture();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
||||
}
|
||||
|
||||
inline NS::String* make_string(std::ostringstream& os) {
|
||||
std::string string = os.str();
|
||||
return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
|
||||
}
|
||||
|
||||
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
label << "Stream " << index;
|
||||
queue->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void debug_set_primitive_buffer_label(
|
||||
MTL::CommandBuffer* command_buffer,
|
||||
Primitive& primitive) {
|
||||
#ifdef MLX_METAL_DEBUG
|
||||
std::ostringstream label;
|
||||
primitive.print(label);
|
||||
command_buffer->setLabel(make_string(label));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
bool start_capture(std::string path) {
|
||||
return false;
|
||||
}
|
||||
void stop_capture() {}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
|
||||
|
||||
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
|
||||
void print(std::ostream&) override {}
|
||||
|
||||
DEFINE_PRINT(Synchronize);
|
||||
};
|
||||
|
||||
// Initialize the static tracing counter from transforms_impl.h .
|
||||
|
Reference in New Issue
Block a user