diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu
index 05e6fbf29..68dcd26f0 100644
--- a/mlx/backend/cuda/scaled_dot_product_attention.cu
+++ b/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -67,6 +67,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
     return it->second;
   }
 
+  nvtx3::scoped_range r("get_sdpa_forward_graph");
+
   // Set up new graph
   auto graph = std::make_shared<fe::graph::Graph>();
 
@@ -141,6 +143,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
   // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
   if (cudnnGetVersion() < 90600) {
+    nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building");
+
     auto build_status = graph->build(handle, {fe::HeurMode_t::A});
     if (!build_status.is_good()) {
       throw std::runtime_error(
@@ -437,6 +441,10 @@ void ScaledDotProductAttention::eval_gpu(
   const auto& k = copy_unless(is_matrix_contiguous, k_pre);
   const auto& v = copy_unless(is_matrix_contiguous, v_pre);
 
+  for (const auto& cp : copies) {
+    encoder.add_temporary(cp);
+  }
+
   int64_t str_oD = 1;
   int64_t str_oH = o.shape(3);
   int64_t str_oL = o.shape(1) * str_oH;