From 923c7028a7555248f62ad51ff0492b437ef7a271 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 25 Jul 2025 13:31:02 -0700 Subject: [PATCH] Add more nvtx range for debug --- mlx/backend/cuda/scaled_dot_product_attention.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 05e6fbf29..68dcd26f0 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -67,6 +67,8 @@ std::shared_ptr get_sdpa_forward_graph( return it->second; } + nvtx3::scoped_range r("get_sdpa_forward_graph"); + // Set up new graph auto graph = std::make_shared(); @@ -141,6 +143,8 @@ std::shared_ptr get_sdpa_forward_graph( // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. if (cudnnGetVersion() < 90600) { + nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building"); + auto build_status = graph->build(handle, {fe::HeurMode_t::A}); if (!build_status.is_good()) { throw std::runtime_error( @@ -437,6 +441,10 @@ void ScaledDotProductAttention::eval_gpu( const auto& k = copy_unless(is_matrix_contiguous, k_pre); const auto& v = copy_unless(is_matrix_contiguous, v_pre); + for (const auto& cp : copies) { + encoder.add_temporary(cp); + } + int64_t str_oD = 1; int64_t str_oH = o.shape(3); int64_t str_oL = o.shape(1) * str_oH;