From 52dc8c8cd58cd55b21c8e33486b6516061ab3f61 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 5 Jun 2025 11:55:12 +0900 Subject: [PATCH] Add profiler annotations in common primitives for CUDA backend (#2244) --- mlx/backend/cuda/CMakeLists.txt | 2 ++ mlx/backend/gpu/primitives.cpp | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 8c9a40d03..c991c2094 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -17,6 +17,8 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) +target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) + # Enable defining device lambda functions. target_compile_options(mlx PRIVATE "$<$:--extended-lambda>") diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index cd9296075..938923977 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -5,9 +5,17 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" +#if defined(MLX_USE_CUDA) +#include +#endif + #include +#if defined(MLX_USE_CUDA) +#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message) +#else #define MLX_PROFILER_RANGE(message) +#endif namespace mlx::core {