mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Add Metal debug option and capture functions (#707)
* Add Metal debug option and capture functions * Add brief Metal debugger documentation * doc nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -12,6 +12,7 @@
 | 
			
		||||
#include "mlx/backend/metal/device.h"
 | 
			
		||||
#include "mlx/backend/metal/metal.h"
 | 
			
		||||
#include "mlx/backend/metal/mps/gemm.h"
 | 
			
		||||
#include "mlx/backend/metal/utils.h"
 | 
			
		||||
 | 
			
		||||
namespace fs = std::filesystem;
 | 
			
		||||
 | 
			
		||||
@@ -145,6 +146,7 @@ void Device::new_queue(int index) {
 | 
			
		||||
  // We lock this as a critical section for safety
 | 
			
		||||
  const std::lock_guard<std::mutex> lock(mtx_);
 | 
			
		||||
  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
 | 
			
		||||
  debug_set_stream_queue_label(q, index);
 | 
			
		||||
  if (!q) {
 | 
			
		||||
    throw std::runtime_error(
 | 
			
		||||
        "[metal::Device] Failed to make new command queue.");
 | 
			
		||||
 
 | 
			
		||||
@@ -37,11 +37,17 @@ set(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
function(build_kernel_base TARGET SRCFILE DEPS)
 | 
			
		||||
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
 | 
			
		||||
  if(MLX_METAL_DEBUG)
 | 
			
		||||
    set(METAL_FLAGS ${METAL_FLAGS}
 | 
			
		||||
        -gline-tables-only
 | 
			
		||||
        -frecord-sources)
 | 
			
		||||
  endif()
 | 
			
		||||
  add_custom_command(
 | 
			
		||||
    COMMAND xcrun -sdk macosx metal -Wall -Wextra
 | 
			
		||||
                  -fno-fast-math
 | 
			
		||||
                  -c ${SRCFILE} 
 | 
			
		||||
                  -I${PROJECT_SOURCE_DIR} 
 | 
			
		||||
    COMMAND xcrun -sdk macosx metal
 | 
			
		||||
                  ${METAL_FLAGS}
 | 
			
		||||
                  -c ${SRCFILE}
 | 
			
		||||
                  -I${PROJECT_SOURCE_DIR}
 | 
			
		||||
                  -o ${TARGET}.air
 | 
			
		||||
    DEPENDS ${SRCFILE} ${DEPS}
 | 
			
		||||
    OUTPUT ${TARGET}.air
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@
 | 
			
		||||
#include <memory>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/metal/device.h"
 | 
			
		||||
#include "mlx/backend/metal/utils.h"
 | 
			
		||||
#include "mlx/primitives.h"
 | 
			
		||||
#include "mlx/scheduler.h"
 | 
			
		||||
 | 
			
		||||
@@ -15,6 +16,9 @@ bool is_available() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int max_ops_per_buffer() {
 | 
			
		||||
#ifdef MLX_METAL_DEBUG
 | 
			
		||||
  return 1;
 | 
			
		||||
#else
 | 
			
		||||
  auto get_val = []() {
 | 
			
		||||
    if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
 | 
			
		||||
      return atoi(buff_str);
 | 
			
		||||
@@ -24,6 +28,7 @@ int max_ops_per_buffer() {
 | 
			
		||||
  };
 | 
			
		||||
  static int max_ops_per_buffer_ = get_val();
 | 
			
		||||
  return max_ops_per_buffer_;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
 | 
			
		||||
@@ -74,6 +79,8 @@ std::function<void()> make_task(
 | 
			
		||||
      if (arr.is_tracer()) {
 | 
			
		||||
        inputs = arr.inputs();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      debug_set_primitive_buffer_label(command_buffer, arr.primitive());
 | 
			
		||||
      arr.primitive().eval_gpu(arr.inputs(), outputs);
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<std::shared_ptr<array::Data>> buffers;
 | 
			
		||||
@@ -108,4 +115,31 @@ std::function<void()> make_task(
 | 
			
		||||
  return task;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool start_capture(std::string path, id object) {
 | 
			
		||||
  auto pool = new_scoped_memory_pool();
 | 
			
		||||
 | 
			
		||||
  auto descriptor = MTL::CaptureDescriptor::alloc()->init();
 | 
			
		||||
  descriptor->setCaptureObject(object);
 | 
			
		||||
 | 
			
		||||
  if (path.length() > 0) {
 | 
			
		||||
    auto string = NS::String::string(path.c_str(), NS::UTF8StringEncoding);
 | 
			
		||||
    auto url = NS::URL::fileURLWithPath(string);
 | 
			
		||||
    descriptor->setDestination(MTL::CaptureDestinationGPUTraceDocument);
 | 
			
		||||
    descriptor->setOutputURL(url);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto manager = MTL::CaptureManager::sharedCaptureManager();
 | 
			
		||||
  return manager->startCapture(descriptor, nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool start_capture(std::string path) {
 | 
			
		||||
  auto& device = metal::device(mlx::core::Device::gpu);
 | 
			
		||||
  return start_capture(path, device.mtl_device());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void stop_capture() {
 | 
			
		||||
  auto manager = MTL::CaptureManager::sharedCaptureManager();
 | 
			
		||||
  manager->stopCapture();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -66,4 +66,8 @@ std::function<void()> make_task(
 | 
			
		||||
    std::vector<std::shared_future<void>> deps,
 | 
			
		||||
    std::shared_ptr<std::promise<void>> p);
 | 
			
		||||
 | 
			
		||||
/** Start capturing a GPU trace. A non-empty `path` is used as the absolute output file for the trace; returns whether the capture was started. */
 | 
			
		||||
bool start_capture(std::string path = "");
 | 
			
		||||
void stop_capture();
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@
 | 
			
		||||
 | 
			
		||||
#include "mlx/array.h"
 | 
			
		||||
#include "mlx/backend/metal/device.h"
 | 
			
		||||
#include "mlx/primitives.h"
 | 
			
		||||
 | 
			
		||||
namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
@@ -123,6 +124,29 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
 | 
			
		||||
  return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 * Build an `NS::String` from the text accumulated in `os`, interpreting the
 * bytes as UTF-8.
 */
inline NS::String* make_string(std::ostringstream& os) {
  // The temporary returned by str() stays alive until the end of this full
  // expression, so the c_str() pointer is valid for the duration of the call.
  return NS::String::string(os.str().c_str(), NS::UTF8StringEncoding);
}
 | 
			
		||||
 | 
			
		||||
/**
 * Label a command queue as "Stream <index>".
 *
 * Compiles to a no-op unless MLX_METAL_DEBUG is defined.
 *
 * @param queue Command queue to label; a null queue is ignored.
 * @param index Stream index written into the label.
 */
inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
#ifdef MLX_METAL_DEBUG
  // Device::new_queue calls this before it checks whether queue creation
  // succeeded, so do not invoke a member function through a null pointer.
  if (queue == nullptr) {
    return;
  }
  std::ostringstream label;
  label << "Stream " << index;
  queue->setLabel(make_string(label));
#endif
}
 | 
			
		||||
 | 
			
		||||
/**
 * Label a command buffer with the text that `primitive.print` writes.
 *
 * Compiles to a no-op unless MLX_METAL_DEBUG is defined.
 */
inline void debug_set_primitive_buffer_label(
    MTL::CommandBuffer* command_buffer,
    Primitive& primitive) {
#ifdef MLX_METAL_DEBUG
  std::ostringstream description;
  primitive.print(description);
  NS::String* text = make_string(description);
  command_buffer->setLabel(text);
#endif
}
 | 
			
		||||
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
 
 | 
			
		||||
@@ -39,5 +39,9 @@ size_t set_memory_limit(size_t, bool) {
 | 
			
		||||
size_t set_cache_limit(size_t) {
 | 
			
		||||
  return 0;
 | 
			
		||||
}
 | 
			
		||||
// Stub: never starts a capture. `path` is ignored and the result is always
// false.
bool start_capture(std::string path) {
  static_cast<void>(path); // intentionally unused
  return false;
}
 | 
			
		||||
// Intentionally empty stub: this variant performs no capture work, so there
// is nothing to stop.
void stop_capture() {}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::metal
 | 
			
		||||
 
 | 
			
		||||
@@ -28,7 +28,8 @@ class Synchronizer : public Primitive {
 | 
			
		||||
 | 
			
		||||
  void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
 | 
			
		||||
  void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
 | 
			
		||||
  void print(std::ostream&) override {}
 | 
			
		||||
 | 
			
		||||
  DEFINE_PRINT(Synchronize);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Initialize the static tracing counter from transforms_impl.h .
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user