mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add more NVTX ranges for debugging
This commit is contained in:
		| @@ -67,6 +67,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph( | ||||
|     return it->second; | ||||
|   } | ||||
|  | ||||
|   nvtx3::scoped_range r("get_sdpa_forward_graph"); | ||||
|  | ||||
|   // Set up new graph | ||||
|   auto graph = std::make_shared<fe::graph::Graph>(); | ||||
|  | ||||
| @@ -141,6 +143,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph( | ||||
|  | ||||
|   // cuDNN only supports native CUDA graphs for sdpa in 9.6 or above. | ||||
|   if (cudnnGetVersion() < 90600) { | ||||
|     nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building"); | ||||
|  | ||||
|     auto build_status = graph->build(handle, {fe::HeurMode_t::A}); | ||||
|     if (!build_status.is_good()) { | ||||
|       throw std::runtime_error( | ||||
| @@ -437,6 +441,10 @@ void ScaledDotProductAttention::eval_gpu( | ||||
|     const auto& k = copy_unless(is_matrix_contiguous, k_pre); | ||||
|     const auto& v = copy_unless(is_matrix_contiguous, v_pre); | ||||
|  | ||||
|     for (const auto& cp : copies) { | ||||
|       encoder.add_temporary(cp); | ||||
|     } | ||||
|  | ||||
|     int64_t str_oD = 1; | ||||
|     int64_t str_oH = o.shape(3); | ||||
|     int64_t str_oL = o.shape(1) * str_oH; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani