mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Add more nvtx range for debug
This commit is contained in:
@@ -67,6 +67,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
||||
return it->second;
|
||||
}
|
||||
|
||||
nvtx3::scoped_range r("get_sdpa_forward_graph");
|
||||
|
||||
// Set up new graph
|
||||
auto graph = std::make_shared<fe::graph::Graph>();
|
||||
|
||||
@@ -141,6 +143,8 @@ std::shared_ptr<fe::graph::Graph> get_sdpa_forward_graph(
|
||||
|
||||
// cuDNN only supports native CUDA graphs for sdpa in 9.6 or above.
|
||||
if (cudnnGetVersion() < 90600) {
|
||||
nvtx3::scoped_range r("get_sdpa_forward_graph::graph_building");
|
||||
|
||||
auto build_status = graph->build(handle, {fe::HeurMode_t::A});
|
||||
if (!build_status.is_good()) {
|
||||
throw std::runtime_error(
|
||||
@@ -437,6 +441,10 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||
|
||||
for (const auto& cp : copies) {
|
||||
encoder.add_temporary(cp);
|
||||
}
|
||||
|
||||
int64_t str_oD = 1;
|
||||
int64_t str_oH = o.shape(3);
|
||||
int64_t str_oL = o.shape(1) * str_oH;
|
||||
|
Reference in New Issue
Block a user