mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add profiler annotations in common primitives for CUDA backend (#2244)
This commit is contained in:
parent
aede70e81d
commit
52dc8c8cd5
@ -17,6 +17,8 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
|
||||||
|
|
||||||
|
target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
|
||||||
|
|
||||||
# Enable defining device lambda functions.
|
# Enable defining device lambda functions.
|
||||||
target_compile_options(mlx
|
target_compile_options(mlx
|
||||||
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
|
||||||
|
@ -5,9 +5,17 @@
|
|||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
|
|
||||||
|
#if defined(MLX_USE_CUDA)
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#if defined(MLX_USE_CUDA)
|
||||||
|
#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
|
||||||
|
#else
|
||||||
#define MLX_PROFILER_RANGE(message)
|
#define MLX_PROFILER_RANGE(message)
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user